Python tensorflow.python.training.saver.Saver() Examples
The following are 30 code examples of tensorflow.python.training.saver.Saver().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module tensorflow.python.training.saver, or try the search function.
Example #1
Source File: evaluation_test.py From tf-slim with Apache License 2.0 | 6 votes |
def testReturnsSingleCheckpointIfOneCheckpointFound(self):
    """Check that exactly one checkpoint is iterated when one was saved."""
    ckpt_dir = tempfile.mkdtemp('one_checkpoint_found')
    if not gfile.Exists(ckpt_dir):
        gfile.MakeDirs(ckpt_dir)

    step = variables.get_or_create_global_step()
    # Saves the global step.
    step_saver = saver_lib.Saver()

    with self.cached_session() as sess:
        sess.run(variables_lib.global_variables_initializer())
        ckpt_prefix = os.path.join(ckpt_dir, 'model.ckpt')
        step_saver.save(sess, ckpt_prefix, global_step=step)

    # NOTE(review): timeout=0 presumably makes the iterator stop as soon as
    # no further checkpoint appears -- confirm against checkpoints_iterator.
    found = sum(
        1 for _ in evaluation.checkpoints_iterator(ckpt_dir, timeout=0))
    self.assertEqual(found, 1)
Example #2
Source File: exporter.py From ros_people_object_detection_tensorflow with Apache License 2.0 | 6 votes |
def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file):
    """Write a new checkpoint whose variables hold their moving averages.

    The graph's variables are restored from the moving-average (shadow)
    entries of the current checkpoint and then saved to a new checkpoint.

    Args:
      graph: a tf.Graph object.
      current_checkpoint_file: a checkpoint containing both original
        variables and their moving averages.
      new_checkpoint_file: file path to write a new checkpoint.
    """
    with graph.as_default():
        # Only variables_to_restore() is used on this object; the 0.0 decay
        # is never applied to anything here.
        ema = tf.train.ExponentialMovingAverage(0.0)
        shadow_restore_map = ema.variables_to_restore()
        with tf.Session() as sess:
            # Load through the moving-average mapping ...
            tf.train.Saver(shadow_restore_map).restore(
                sess, current_checkpoint_file)
            # ... then write out with a default Saver.
            tf.train.Saver().save(sess, new_checkpoint_file)
Example #3
Source File: exporter.py From DOTA_models with Apache License 2.0 | 6 votes |
def get_frozen_graph_def(inference_graph_def, use_moving_averages,
                         input_checkpoint, output_node_names):
    """Freeze all variables in a graph definition.

    Args:
      inference_graph_def: graph definition passed on as `input_graph_def`.
      use_moving_averages: when truthy, the saver is built from the
        moving-average restore mapping instead of the default variable set.
      input_checkpoint: checkpoint passed on to the freezing helper.
      output_node_names: output node names passed on to the freezing helper.

    Returns:
      The value returned by `freeze_graph_with_def_protos`.
    """
    if use_moving_averages:
        ema = tf.train.ExponentialMovingAverage(0.0)
        saver = tf.train.Saver(ema.variables_to_restore())
    else:
        saver = tf.train.Saver()

    return freeze_graph_with_def_protos(
        input_graph_def=inference_graph_def,
        input_saver_def=saver.as_saver_def(),
        input_checkpoint=input_checkpoint,
        output_node_names=output_node_names,
        restore_op_name='save/restore_all',
        filename_tensor_name='save/Const:0',
        clear_devices=True,
        initializer_nodes='')


# TODO: Support batch tf example inputs.
Example #4
Source File: exporter.py From vehicle_counting_tensorflow with MIT License | 6 votes |
def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file):
    """Write a new checkpoint whose variables hold their moving averages.

    The graph's variables are restored from the moving-average (shadow)
    entries of the current checkpoint and then saved to a new checkpoint.

    Args:
      graph: a tf.Graph object.
      current_checkpoint_file: a checkpoint containing both original
        variables and their moving averages.
      new_checkpoint_file: file path to write a new checkpoint.
    """
    with graph.as_default():
        # Only variables_to_restore() is used on this object; the 0.0 decay
        # is never applied to anything here.
        ema = tf.train.ExponentialMovingAverage(0.0)
        shadow_restore_map = ema.variables_to_restore()
        with tf.Session() as sess:
            # Load through the moving-average mapping ...
            tf.train.Saver(shadow_restore_map).restore(
                sess, current_checkpoint_file)
            # ... then write out with a default Saver.
            tf.train.Saver().save(sess, new_checkpoint_file)
Example #5
Source File: evaluation_test.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def testRestoredModelPerformance(self):
    """Save the model, evaluate it once from the checkpoint, check accuracy."""
    tmp_root = self.get_temp_dir()
    checkpoint_path = os.path.join(tmp_root, 'model.ckpt')
    log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/')

    # First, save out the current model to a checkpoint:
    init_op = control_flow_ops.group(
        variables.global_variables_initializer(),
        variables.local_variables_initializer())
    v1_saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1)
    with self.test_session() as sess:
        sess.run(init_op)
        v1_saver.save(sess, checkpoint_path)

    # Next, determine the metric to evaluate:
    value_op, update_op = metric_ops.streaming_accuracy(
        self._predictions, self._labels)

    # Run the evaluation and verify the results:
    measured_accuracy = evaluation.evaluate_once(
        '', checkpoint_path, log_dir, eval_op=update_op, final_op=value_op)
    self.assertAlmostEqual(measured_accuracy, self._expected_accuracy)
Example #6
Source File: learning_test.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def testTrainWithNoneAsLogdirWhenUsingSaverRaisesError(self):
    """learning.train must raise ValueError for logdir=None with a saver."""
    with ops.Graph().as_default():
        random_seed.set_random_seed(0)
        features = constant_op.constant(self._inputs, dtype=dtypes.float32)
        targets = constant_op.constant(self._labels, dtype=dtypes.float32)

        predictions = LogisticClassifier(features)
        # log_loss's return value is unused; the total is read back below.
        loss_ops.log_loss(predictions, targets)
        total_loss = loss_ops.get_total_loss()

        sgd = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
        train_op = learning.create_train_op(total_loss, sgd)
        saver = saver_lib.Saver()

        with self.assertRaises(ValueError):
            learning.train(
                train_op,
                None,
                init_op=None,
                number_of_steps=300,
                saver=saver)
Example #7
Source File: convert_ckpt_to_pb.py From SSH-TensorFlow with MIT License | 6 votes |
def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file):
    """Write a new checkpoint whose variables hold their moving averages.

    The graph's variables are restored from the moving-average (shadow)
    entries of the current checkpoint and then saved to a new checkpoint.

    Args:
      graph: a tf.Graph object.
      current_checkpoint_file: a checkpoint containing both original
        variables and their moving averages.
      new_checkpoint_file: file path to write a new checkpoint.
    """
    with graph.as_default():
        # Only variables_to_restore() is used on this object; the 0.0 decay
        # is never applied to anything here.
        ema = tf.train.ExponentialMovingAverage(0.0)
        shadow_restore_map = ema.variables_to_restore()
        with tf.Session() as sess:
            # Load through the moving-average mapping ...
            tf.train.Saver(shadow_restore_map).restore(
                sess, current_checkpoint_file)
            # ... then write out with a default Saver.
            tf.train.Saver().save(sess, new_checkpoint_file)
Example #8
Source File: exporter.py From Person-Detection-and-Tracking with MIT License | 6 votes |
def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file):
    """Write a new checkpoint whose variables hold their moving averages.

    The graph's variables are restored from the moving-average (shadow)
    entries of the current checkpoint and then saved to a new checkpoint.

    Args:
      graph: a tf.Graph object.
      current_checkpoint_file: a checkpoint containing both original
        variables and their moving averages.
      new_checkpoint_file: file path to write a new checkpoint.
    """
    with graph.as_default():
        # Only variables_to_restore() is used on this object; the 0.0 decay
        # is never applied to anything here.
        ema = tf.train.ExponentialMovingAverage(0.0)
        shadow_restore_map = ema.variables_to_restore()
        with tf.Session() as sess:
            # Load through the moving-average mapping ...
            tf.train.Saver(shadow_restore_map).restore(
                sess, current_checkpoint_file)
            # ... then write out with a default Saver.
            tf.train.Saver().save(sess, new_checkpoint_file)
Example #9
Source File: exporter.py From multi_task_test with Apache License 2.0 | 6 votes |
def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file):
    """Write a new checkpoint whose variables hold their moving averages.

    The graph's variables are restored from the moving-average (shadow)
    entries of the current checkpoint and then saved to a new checkpoint.

    Args:
      graph: A tf.Graph object.
      current_checkpoint_file: A checkpoint containing both original
        variables and their moving averages.
      new_checkpoint_file: File path to write a new checkpoint.
    """
    with graph.as_default():
        # Only variables_to_restore() is used on this object; the 0.0 decay
        # is never applied to anything here.
        ema = tf.train.ExponentialMovingAverage(0.0)
        shadow_restore_map = ema.variables_to_restore()
        with tf.Session() as sess:
            # Load through the moving-average mapping ...
            tf.train.Saver(shadow_restore_map).restore(
                sess, current_checkpoint_file)
            # ... then write out with a default Saver.
            tf.train.Saver().save(sess, new_checkpoint_file)
Example #10
Source File: exporter.py From cartoonify with MIT License | 6 votes |
def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file):
    """Write a new checkpoint whose variables hold their moving averages.

    The graph's variables are restored from the moving-average (shadow)
    entries of the current checkpoint and then saved to a new checkpoint.

    Args:
      graph: a tf.Graph object.
      current_checkpoint_file: a checkpoint containing both original
        variables and their moving averages.
      new_checkpoint_file: file path to write a new checkpoint.
    """
    with graph.as_default():
        # Only variables_to_restore() is used on this object; the 0.0 decay
        # is never applied to anything here.
        ema = tf.train.ExponentialMovingAverage(0.0)
        shadow_restore_map = ema.variables_to_restore()
        with tf.Session() as sess:
            # Load through the moving-average mapping ...
            tf.train.Saver(shadow_restore_map).restore(
                sess, current_checkpoint_file)
            # ... then write out with a default Saver.
            tf.train.Saver().save(sess, new_checkpoint_file)
Example #11
Source File: exporter.py From finetune_classification with Apache License 2.0 | 6 votes |
def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file):
    """Write a new checkpoint whose variables hold their moving averages.

    The graph's variables are restored from the moving-average (shadow)
    entries of the current checkpoint and then saved to a new checkpoint.

    Args:
      graph: A tf.Graph object.
      current_checkpoint_file: A checkpoint containing both original
        variables and their moving averages.
      new_checkpoint_file: File path to write a new checkpoint.
    """
    with graph.as_default():
        # Only variables_to_restore() is used on this object; the 0.0 decay
        # is never applied to anything here.
        ema = tf.train.ExponentialMovingAverage(0.0)
        shadow_restore_map = ema.variables_to_restore()
        with tf.Session() as sess:
            # Load through the moving-average mapping ...
            tf.train.Saver(shadow_restore_map).restore(
                sess, current_checkpoint_file)
            # ... then write out with a default Saver.
            tf.train.Saver().save(sess, new_checkpoint_file)
Example #12
Source File: exporter.py From yolo_v2 with Apache License 2.0 | 6 votes |
def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file):
    """Write a new checkpoint whose variables hold their moving averages.

    The graph's variables are restored from the moving-average (shadow)
    entries of the current checkpoint and then saved to a new checkpoint.

    Args:
      graph: a tf.Graph object.
      current_checkpoint_file: a checkpoint containing both original
        variables and their moving averages.
      new_checkpoint_file: file path to write a new checkpoint.
    """
    with graph.as_default():
        # Only variables_to_restore() is used on this object; the 0.0 decay
        # is never applied to anything here.
        ema = tf.train.ExponentialMovingAverage(0.0)
        shadow_restore_map = ema.variables_to_restore()
        with tf.Session() as sess:
            # Load through the moving-average mapping ...
            tf.train.Saver(shadow_restore_map).restore(
                sess, current_checkpoint_file)
            # ... then write out with a default Saver.
            tf.train.Saver().save(sess, new_checkpoint_file)
Example #13
Source File: exporter.py From Traffic-Rule-Violation-Detection-System with MIT License | 6 votes |
def replace_variable_values_with_moving_averages(graph,
                                                 current_checkpoint_file,
                                                 new_checkpoint_file):
    """Write a new checkpoint whose variables hold their moving averages.

    The graph's variables are restored from the moving-average (shadow)
    entries of the current checkpoint and then saved to a new checkpoint.

    Args:
      graph: a tf.Graph object.
      current_checkpoint_file: a checkpoint containing both original
        variables and their moving averages.
      new_checkpoint_file: file path to write a new checkpoint.
    """
    with graph.as_default():
        # Only variables_to_restore() is used on this object; the 0.0 decay
        # is never applied to anything here.
        ema = tf.train.ExponentialMovingAverage(0.0)
        shadow_restore_map = ema.variables_to_restore()
        with tf.Session() as sess:
            # Load through the moving-average mapping ...
            tf.train.Saver(shadow_restore_map).restore(
                sess, current_checkpoint_file)
            # ... then write out with a default Saver.
            tf.train.Saver().save(sess, new_checkpoint_file)
Example #14
Source File: variables_test.py From tf-slim with Apache License 2.0 | 6 votes |
def create_checkpoint_from_values(self,
                                  var_names_to_values,
                                  checkpoint_dir,
                                  global_step=None):
    """Create a checkpoint from a mapping of variable names to values.

    Args:
      var_names_to_values: a map from variable names to values.
      checkpoint_dir: the directory where the checkpoint will be saved.
      global_step: the global step used to save the checkpoint.

    Returns:
      the model_path to the checkpoint.
    """
    with session.Session('', graph=ops.Graph()) as sess:
        # Create a set of variables to save in the checkpoint.
        ckpt_vars = [
            variables_lib.VariableV1(value, name=name)
            for name, value in var_names_to_values.items()
        ]
        saver = saver_lib.Saver(ckpt_vars)
        sess.run(variables_lib.variables_initializer(ckpt_vars))
        # Save the initialized values in the file at 'checkpoint_dir'
        return saver.save(sess, checkpoint_dir, global_step=global_step)
Example #15
Source File: variables_test.py From tf-slim with Apache License 2.0 | 6 votes |
def create_checkpoint_from_values(self,
                                  var_names_to_values,
                                  checkpoint_dir,
                                  global_step=None):
    """Create a checkpoint from a mapping of variable names to values.

    Args:
      var_names_to_values: a map from variable names to values.
      checkpoint_dir: the directory where the checkpoint will be saved.
      global_step: the global step used to save the checkpoint.

    Returns:
      the model_path to the checkpoint.
    """
    with session.Session('', graph=ops.Graph()) as sess:
        # Create a set of variables to save in the checkpoint.
        ckpt_vars = [
            variables_lib.VariableV1(value, name=name)
            for name, value in var_names_to_values.items()
        ]
        saver = saver_lib.Saver(ckpt_vars)
        sess.run(variables_lib.variables_initializer(ckpt_vars))
        # Save the initialized values in the file at 'checkpoint_dir'
        return saver.save(sess, checkpoint_dir, global_step=global_step)
Example #16
Source File: ssd_meta_arch_test.py From tensorflow with BSD 2-Clause "Simplified" License | 5 votes |
def test_restore_map_for_classification_ckpt(self):
    """Restore the detection model from a mock classification checkpoint."""
    # Define mock tensorflow classification graph and save variables.
    classification_graph = tf.Graph()
    with classification_graph.as_default():
        image = tf.placeholder(dtype=tf.float32, shape=[1, 20, 20, 3])
        with tf.variable_scope('mock_model'):
            hidden = slim.conv2d(
                image, num_outputs=32, kernel_size=1, scope='layer1')
            slim.conv2d(hidden, num_outputs=3, kernel_size=1, scope='layer2')

        init_op = tf.global_variables_initializer()
        ckpt_writer = tf.train.Saver()
        save_path = self.get_temp_dir()
        with self.test_session() as sess:
            sess.run(init_op)
            saved_model_path = ckpt_writer.save(sess, save_path)

    # Create tensorflow detection graph and load variables from
    # classification checkpoint.
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        random_ints = tf.random_uniform(
            [2, 2, 2, 3], minval=0, maxval=255, dtype=tf.int32)
        inputs = tf.to_float(random_ints)
        preprocessed = self._model.preprocess(inputs)
        predictions = self._model.predict(preprocessed)
        self._model.postprocess(predictions)
        var_map = self._model.restore_map(from_detection_checkpoint=False)
        self.assertIsInstance(var_map, dict)
        ckpt_reader = tf.train.Saver(var_map)
        with self.test_session() as sess:
            ckpt_reader.restore(sess, saved_model_path)
            # NOTE(review): report_uninitialized_variables presumably yields
            # variable names, not Variable objects, so `.name` may fail on
            # any element -- confirm before relying on this check.
            for var in sess.run(tf.report_uninitialized_variables()):
                self.assertNotIn('FeatureExtractor', var.name)
Example #17
Source File: ssd_meta_arch_test.py From HereIsWally with MIT License | 5 votes |
def test_restore_fn_classification(self):
    """Run the model's restore_fn against a mock classification checkpoint."""
    # Define mock tensorflow classification graph and save variables.
    classification_graph = tf.Graph()
    with classification_graph.as_default():
        image = tf.placeholder(dtype=tf.float32, shape=[1, 20, 20, 3])
        with tf.variable_scope('mock_model'):
            hidden = slim.conv2d(
                image, num_outputs=32, kernel_size=1, scope='layer1')
            slim.conv2d(hidden, num_outputs=3, kernel_size=1, scope='layer2')

        init_op = tf.global_variables_initializer()
        ckpt_writer = tf.train.Saver()
        save_path = self.get_temp_dir()
        with self.test_session() as sess:
            sess.run(init_op)
            saved_model_path = ckpt_writer.save(sess, save_path)

    # Create tensorflow detection graph and load variables from
    # classification checkpoint.
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        random_ints = tf.random_uniform(
            [2, 2, 2, 3], minval=0, maxval=255, dtype=tf.int32)
        inputs = tf.to_float(random_ints)
        preprocessed = self._model.preprocess(inputs)
        predictions = self._model.predict(preprocessed)
        self._model.postprocess(predictions)
        restore_fn = self._model.restore_fn(
            saved_model_path, from_detection_checkpoint=False)
        with self.test_session() as sess:
            restore_fn(sess)
            # NOTE(review): report_uninitialized_variables presumably yields
            # variable names, not Variable objects, so `.name` may fail on
            # any element -- confirm before relying on this check.
            for var in sess.run(tf.report_uninitialized_variables()):
                self.assertNotIn('FeatureExtractor', var.name)
Example #18
Source File: ssd_meta_arch_test.py From HereIsWally with MIT License | 5 votes |
def test_restore_fn_detection(self):
    """Save the current graph, then restore it via the model's restore_fn."""
    init_op = tf.global_variables_initializer()
    ckpt_writer = tf_saver.Saver()
    save_path = self.get_temp_dir()
    with self.test_session() as sess:
        sess.run(init_op)
        saved_model_path = ckpt_writer.save(sess, save_path)
        restore_fn = self._model.restore_fn(
            saved_model_path, from_detection_checkpoint=True)
        restore_fn(sess)
        # NOTE(review): report_uninitialized_variables presumably yields
        # variable names, not Variable objects, so `.name` may fail on any
        # element -- confirm before relying on this check.
        for var in sess.run(tf.report_uninitialized_variables()):
            self.assertNotIn('FeatureExtractor', var.name)
Example #19
Source File: ssd_meta_arch_test.py From garbage-object-detection-tensorflow with MIT License | 5 votes |
def test_restore_map_for_classification_ckpt(self):
    """Restore the detection model from a mock classification checkpoint."""
    # Define mock tensorflow classification graph and save variables.
    classification_graph = tf.Graph()
    with classification_graph.as_default():
        image = tf.placeholder(dtype=tf.float32, shape=[1, 20, 20, 3])
        with tf.variable_scope('mock_model'):
            hidden = slim.conv2d(
                image, num_outputs=32, kernel_size=1, scope='layer1')
            slim.conv2d(hidden, num_outputs=3, kernel_size=1, scope='layer2')

        init_op = tf.global_variables_initializer()
        ckpt_writer = tf.train.Saver()
        save_path = self.get_temp_dir()
        with self.test_session() as sess:
            sess.run(init_op)
            saved_model_path = ckpt_writer.save(sess, save_path)

    # Create tensorflow detection graph and load variables from
    # classification checkpoint.
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        random_ints = tf.random_uniform(
            [2, 2, 2, 3], minval=0, maxval=255, dtype=tf.int32)
        inputs = tf.to_float(random_ints)
        preprocessed = self._model.preprocess(inputs)
        predictions = self._model.predict(preprocessed)
        self._model.postprocess(predictions)
        var_map = self._model.restore_map(from_detection_checkpoint=False)
        self.assertIsInstance(var_map, dict)
        ckpt_reader = tf.train.Saver(var_map)
        with self.test_session() as sess:
            ckpt_reader.restore(sess, saved_model_path)
            # NOTE(review): report_uninitialized_variables presumably yields
            # variable names, not Variable objects, so `.name` may fail on
            # any element -- confirm before relying on this check.
            for var in sess.run(tf.report_uninitialized_variables()):
                self.assertNotIn('FeatureExtractor', var.name)
Example #20
Source File: ssd_meta_arch_test.py From garbage-object-detection-tensorflow with MIT License | 5 votes |
def test_restore_map_for_detection_ckpt(self):
    """Save the current graph, then restore it through the model's map."""
    init_op = tf.global_variables_initializer()
    ckpt_writer = tf_saver.Saver()
    save_path = self.get_temp_dir()
    with self.test_session() as sess:
        sess.run(init_op)
        saved_model_path = ckpt_writer.save(sess, save_path)
        var_map = self._model.restore_map(from_detection_checkpoint=True)
        self.assertIsInstance(var_map, dict)
        # A second saver, limited to the model's restore map, reads it back.
        ckpt_reader = tf.train.Saver(var_map)
        ckpt_reader.restore(sess, saved_model_path)
        # NOTE(review): report_uninitialized_variables presumably yields
        # variable names, not Variable objects, so `.name` may fail on any
        # element -- confirm before relying on this check.
        for var in sess.run(tf.report_uninitialized_variables()):
            self.assertNotIn('FeatureExtractor', var.name)
Example #21
Source File: exporter.py From garbage-object-detection-tensorflow with MIT License | 5 votes |
def _write_graph_and_checkpoint(inference_graph_def,
                                model_path,
                                input_saver_def,
                                trained_checkpoint_prefix):
    """Import the graph, restore the trained checkpoint and re-save it.

    Args:
      inference_graph_def: graph definition to import; the `device` field of
        every node in it is cleared in place.
      model_path: path the checkpoint is saved to.
      input_saver_def: saver definition used to build the Saver.
      trained_checkpoint_prefix: checkpoint to restore from.
    """
    # Clear device placements before importing (mutates the argument).
    for node_def in inference_graph_def.node:
        node_def.device = ''
    with tf.Graph().as_default():
        tf.import_graph_def(inference_graph_def, name='')
        with session.Session() as sess:
            ckpt_io = saver_lib.Saver(
                saver_def=input_saver_def, save_relative_paths=True)
            ckpt_io.restore(sess, trained_checkpoint_prefix)
            ckpt_io.save(sess, model_path)
Example #22
Source File: exporter.py From finetune_classification with Apache License 2.0 | 5 votes |
def write_graph_and_checkpoint(inference_graph_def,
                               model_path,
                               input_saver_def,
                               trained_checkpoint_prefix):
    """Write the graph and the checkpoint to disk.

    Args:
      inference_graph_def: graph definition to import; the `device` field of
        every node in it is cleared in place.
      model_path: path the checkpoint is saved to.
      input_saver_def: saver definition used to build the Saver.
      trained_checkpoint_prefix: checkpoint to restore from.
    """
    # Clear device placements before importing (mutates the argument).
    for node_def in inference_graph_def.node:
        node_def.device = ''
    with tf.Graph().as_default():
        tf.import_graph_def(inference_graph_def, name='')
        with session.Session() as sess:
            ckpt_io = saver_lib.Saver(
                saver_def=input_saver_def, save_relative_paths=True)
            ckpt_io.restore(sess, trained_checkpoint_prefix)
            ckpt_io.save(sess, model_path)
Example #23
Source File: ssd_meta_arch_test.py From object_detector_app with MIT License | 5 votes |
def test_restore_fn_classification(self):
    """Run the model's restore_fn against a mock classification checkpoint."""
    # Define mock tensorflow classification graph and save variables.
    classification_graph = tf.Graph()
    with classification_graph.as_default():
        image = tf.placeholder(dtype=tf.float32, shape=[1, 20, 20, 3])
        with tf.variable_scope('mock_model'):
            hidden = slim.conv2d(
                image, num_outputs=32, kernel_size=1, scope='layer1')
            slim.conv2d(hidden, num_outputs=3, kernel_size=1, scope='layer2')

        init_op = tf.global_variables_initializer()
        ckpt_writer = tf.train.Saver()
        save_path = self.get_temp_dir()
        with self.test_session() as sess:
            sess.run(init_op)
            saved_model_path = ckpt_writer.save(sess, save_path)

    # Create tensorflow detection graph and load variables from
    # classification checkpoint.
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        random_ints = tf.random_uniform(
            [2, 2, 2, 3], minval=0, maxval=255, dtype=tf.int32)
        inputs = tf.to_float(random_ints)
        preprocessed = self._model.preprocess(inputs)
        predictions = self._model.predict(preprocessed)
        self._model.postprocess(predictions)
        restore_fn = self._model.restore_fn(
            saved_model_path, from_detection_checkpoint=False)
        with self.test_session() as sess:
            restore_fn(sess)
            # NOTE(review): report_uninitialized_variables presumably yields
            # variable names, not Variable objects, so `.name` may fail on
            # any element -- confirm before relying on this check.
            for var in sess.run(tf.report_uninitialized_variables()):
                self.assertNotIn('FeatureExtractor', var.name)
Example #24
Source File: exporter.py From cartoonify with MIT License | 5 votes |
def _write_graph_and_checkpoint(inference_graph_def,
                                model_path,
                                input_saver_def,
                                trained_checkpoint_prefix):
    """Import the graph, restore the trained checkpoint and re-save it.

    Args:
      inference_graph_def: graph definition to import; the `device` field of
        every node in it is cleared in place.
      model_path: path the checkpoint is saved to.
      input_saver_def: saver definition used to build the Saver.
      trained_checkpoint_prefix: checkpoint to restore from.
    """
    # Clear device placements before importing (mutates the argument).
    for node_def in inference_graph_def.node:
        node_def.device = ''
    with tf.Graph().as_default():
        tf.import_graph_def(inference_graph_def, name='')
        with session.Session() as sess:
            ckpt_io = saver_lib.Saver(
                saver_def=input_saver_def, save_relative_paths=True)
            ckpt_io.restore(sess, trained_checkpoint_prefix)
            ckpt_io.save(sess, model_path)
Example #25
Source File: freeze_graph.py From deepnlp with MIT License | 5 votes |
def _parse_input_saver_proto(input_saver, input_binary):
    """Parse an input TensorFlow Saver file into a SaverDef proto.

    Args:
      input_saver: path of the saver file to read.
      input_binary: truthy to read the file as a binary serialized proto,
        falsy to read it as text format.

    Returns:
      The populated SaverDef, or -1 (after printing a message) when the file
      does not exist.
    """
    if not gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    if input_binary:
        mode = "rb"
    else:
        mode = "r"
    with gfile.FastGFile(input_saver, mode) as saver_file:
        saver_def = saver_pb2.SaverDef()
        if input_binary:
            saver_def.ParseFromString(saver_file.read())
        else:
            text_format.Merge(saver_file.read(), saver_def)
    return saver_def
Example #26
Source File: exporter.py From multi_task_test with Apache License 2.0 | 5 votes |
def write_graph_and_checkpoint(inference_graph_def,
                               model_path,
                               input_saver_def,
                               trained_checkpoint_prefix):
    """Write the graph and the checkpoint to disk.

    Args:
      inference_graph_def: graph definition to import; the `device` field of
        every node in it is cleared in place.
      model_path: path the checkpoint is saved to.
      input_saver_def: saver definition used to build the Saver.
      trained_checkpoint_prefix: checkpoint to restore from.
    """
    # Clear device placements before importing (mutates the argument).
    for node_def in inference_graph_def.node:
        node_def.device = ''
    with tf.Graph().as_default():
        tf.import_graph_def(inference_graph_def, name='')
        with session.Session() as sess:
            ckpt_io = saver_lib.Saver(
                saver_def=input_saver_def, save_relative_paths=True)
            ckpt_io.restore(sess, trained_checkpoint_prefix)
            ckpt_io.save(sess, model_path)
Example #27
Source File: freeze_graph.py From deepnlp with MIT License | 5 votes |
def _parse_input_saver_proto(input_saver, input_binary):
    """Parse an input TensorFlow Saver file into a SaverDef proto.

    Args:
      input_saver: path of the saver file to read.
      input_binary: truthy to read the file as a binary serialized proto,
        falsy to read it as text format.

    Returns:
      The populated SaverDef, or -1 (after printing a message) when the file
      does not exist.
    """
    if not gfile.Exists(input_saver):
        print("Input saver file '" + input_saver + "' does not exist!")
        return -1

    if input_binary:
        mode = "rb"
    else:
        mode = "r"
    with gfile.FastGFile(input_saver, mode) as saver_file:
        saver_def = saver_pb2.SaverDef()
        if input_binary:
            saver_def.ParseFromString(saver_file.read())
        else:
            text_format.Merge(saver_file.read(), saver_def)
    return saver_def
Example #28
Source File: exporter.py From Person-Detection-and-Tracking with MIT License | 5 votes |
def write_graph_and_checkpoint(inference_graph_def,
                               model_path,
                               input_saver_def,
                               trained_checkpoint_prefix):
    """Write the graph and the checkpoint to disk.

    Args:
      inference_graph_def: graph definition to import; the `device` field of
        every node in it is cleared in place.
      model_path: path the checkpoint is saved to.
      input_saver_def: saver definition used to build the Saver.
      trained_checkpoint_prefix: checkpoint to restore from.
    """
    # Clear device placements before importing (mutates the argument).
    for node_def in inference_graph_def.node:
        node_def.device = ''
    with tf.Graph().as_default():
        tf.import_graph_def(inference_graph_def, name='')
        with session.Session() as sess:
            ckpt_io = saver_lib.Saver(
                saver_def=input_saver_def, save_relative_paths=True)
            ckpt_io.restore(sess, trained_checkpoint_prefix)
            ckpt_io.save(sess, model_path)
Example #29
Source File: convert_ckpt_to_pb.py From SSH-TensorFlow with MIT License | 5 votes |
def write_graph_and_checkpoint(inference_graph_def,
                               model_path,
                               input_saver_def,
                               trained_checkpoint_prefix):
    """Write the graph and the checkpoint to disk.

    Args:
      inference_graph_def: graph definition to import; the `device` field of
        every node in it is cleared in place.
      model_path: path the checkpoint is saved to.
      input_saver_def: saver definition used to build the Saver.
      trained_checkpoint_prefix: checkpoint to restore from.
    """
    # Clear device placements before importing (mutates the argument).
    for node_def in inference_graph_def.node:
        node_def.device = ''
    with tf.Graph().as_default():
        tf.import_graph_def(inference_graph_def, name='')
        with session.Session() as sess:
            ckpt_io = saver_lib.Saver(
                saver_def=input_saver_def, save_relative_paths=True)
            ckpt_io.restore(sess, trained_checkpoint_prefix)
            ckpt_io.save(sess, model_path)
Example #30
Source File: ssd_meta_arch_test.py From object_detector_app with MIT License | 5 votes |
def test_restore_fn_detection(self):
    """Save the current graph, then restore it via the model's restore_fn."""
    init_op = tf.global_variables_initializer()
    ckpt_writer = tf_saver.Saver()
    save_path = self.get_temp_dir()
    with self.test_session() as sess:
        sess.run(init_op)
        saved_model_path = ckpt_writer.save(sess, save_path)
        restore_fn = self._model.restore_fn(
            saved_model_path, from_detection_checkpoint=True)
        restore_fn(sess)
        # NOTE(review): report_uninitialized_variables presumably yields
        # variable names, not Variable objects, so `.name` may fail on any
        # element -- confirm before relying on this check.
        for var in sess.run(tf.report_uninitialized_variables()):
            self.assertNotIn('FeatureExtractor', var.name)