Python tensorflow.python.pywrap_tensorflow.NewCheckpointReader() Examples

The following are 30 code examples of tensorflow.python.pywrap_tensorflow.NewCheckpointReader(), collected from open-source projects. Each example lists its source file, the project it comes from, and that project's license.
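Before turning to the examples, here is a minimal sketch of the typical pattern (TF 1.x; the checkpoint path below is a placeholder): open a reader on a checkpoint prefix, list the stored variables and their shapes, and read individual tensors back as NumPy arrays. The same reader is also exposed publicly as tf.train.NewCheckpointReader (and as tf.train.load_checkpoint in TF 2.x).

from tensorflow.python import pywrap_tensorflow

# Placeholder: the prefix of a V2 checkpoint, i.e. the path without the
# .index / .data-00000-of-00001 suffixes.
checkpoint_path = "/tmp/model.ckpt-10000"

reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for name in sorted(var_to_shape_map):
    print(name, var_to_shape_map[name])  # variable name and its shape
    # reader.get_tensor(name) returns the stored value as a NumPy array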
Example #1
Source File: tf_policy_network.py    From crossgap_il_rl with GNU General Public License v2.0
def resort_para_form_checkpoint(self, _ckpt_name_vec, graph, sess):
        # with tf.name_scope("restore"):
        if isinstance(_ckpt_name_vec, list):
            ckpt_name_vec = _ckpt_name_vec
        else:
            ckpt_name_vec = [_ckpt_name_vec]

        with tf.name_scope("restore"):
            for ckpt_name in ckpt_name_vec:
                print("===== Restore data from %s =====" % ckpt_name)
                reader = pywrap_tensorflow.NewCheckpointReader(ckpt_name)
                var_to_shape_map = reader.get_variable_to_shape_map()
                for key in var_to_shape_map:
                    # print("tensor_name: ", key)
                    # print(reader.get_tensor(key))
                    # tensor = graph.get_tensor_by_name(key)
                    try:
                        tensor = graph.get_tensor_by_name(key + ":0")
                        sess.run(tf.assign(tensor, reader.get_tensor(key)))
                        # print(tensor)
                    except Exception:
                        # print(key, " can not be restored")
                        pass 
Example #2
Source File: inspect_checkpoint.py    From tf.fashionAI with Apache License 2.0
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
    try:
        reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        if all_tensors:
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in var_to_shape_map:
                print("tensor_name: ", key)
                print(reader.get_tensor(key))
        elif not tensor_name:
            print(reader.debug_string().decode("utf-8"))
        else:
            print("tensor_name: ", tensor_name)
            print(reader.get_tensor(tensor_name))
    except Exception as e:  # pylint: disable=broad-except
        print(str(e))
        if "corrupted compressed block contents" in str(e):
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.") 
Example #3
Source File: tf_policy_network.py    From crossgap_il_rl with GNU General Public License v2.0
def resore_form_rl_net(self,ckpt_name, graph, sess):
        print("Restore form RL net")

        print("===== Prase data from %s =====" % ckpt_name)
        net_prefix = 'pi/pi'
        reader = pywrap_tensorflow.NewCheckpointReader(ckpt_name)
        var_to_shape_map = reader.get_variable_to_shape_map()
        for _key in var_to_shape_map:
            print(_key)
            # print("tensor_name: ", key)
            # print(reader.get_tensor(key))
            # tensor = graph.get_tensor_by_name(key)
            if (str(_key).startswith('%s/net/'%net_prefix) or
                str(_key).startswith('%s/Trajectory_follower_mlp_net/'%net_prefix)):
                notaion_list =  [m.start() for m in re.finditer('/', _key)]
                key = _key[int(notaion_list[1]+1):len(_key)]+ ":0"
                # print(key)
                try:
                    tensor = graph.get_tensor_by_name(key)
                    sess.run(tf.assign(tensor, reader.get_tensor(_key)))
                    # print(tensor)
                except Exception as e:
                    print(key, " can not be restored, e= ",str(e))
                    pass 
Example #4
Source File: checkpint_inspect.py    From SSD.TensorFlow with Apache License 2.0
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
    try:
        reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        if all_tensors:
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in var_to_shape_map:
                print("tensor_name: ", key)
                print(reader.get_tensor(key))
        elif not tensor_name:
            print(reader.debug_string().decode("utf-8"))
        else:
            print("tensor_name: ", tensor_name)
            print(reader.get_tensor(tensor_name))
    except Exception as e:  # pylint: disable=broad-except
        print(str(e))
        if "corrupted compressed block contents" in str(e):
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.") 
Example #5
Source File: checkpint_inspect.py    From inference with Apache License 2.0
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
    try:
        reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        if all_tensors:
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in var_to_shape_map:
                print("tensor_name: ", key)
                print(reader.get_tensor(key))
        elif not tensor_name:
            print(reader.debug_string().decode("utf-8"))
        else:
            print("tensor_name: ", tensor_name)
            print(reader.get_tensor(tensor_name))
    except Exception as e:  # pylint: disable=broad-except
        print(str(e))
        if "corrupted compressed block contents" in str(e):
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.") 
Example #6
Source File: npz_file_to_checkpoint.py    From will-people-like-your-image with GNU Lesser General Public License v3.0
def create_model_from_npz_file(npz, model, target):
    """Creates a tensorflow model from a given npz structure in which the variables for the desired model are stored.
        npz: Path to the npz structure containing files representing the variables in the model.
        model: Path in which the final model should be stored
        target: A target model which contains the desired names for the structure
    """
    reader = pywrap_tensorflow.NewCheckpointReader(target)
    target_map = reader.get_variable_to_shape_map()

    variables = variables_dictionary_from_npz_file(npz)
    i = 0
    for key in variables:

        if key_contained_in_map(key, target_map):
            name = 'var' + str(i)
            val = tf.Variable(variables[key], name=key)
            exec(name + " = val")
            i += 1

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        save_path = saver.save(sess, model)
        print("Model saved in file: %s" % save_path) 
Example #7
Source File: SSD512.py    From Object-Detection-API-Tensorflow with MIT License
def __init__(self, config, data_provider):
        assert config['mode'] in ['train', 'test']
        assert config['data_format'] in ['channels_first', 'channels_last']
        self.config = config
        self.data_provider = data_provider
        self.input_size = 512
        if config['data_format'] == 'channels_last':
            self.data_shape = [512, 512, 3]
        else:
            self.data_shape = [3, 512, 512]
        self.num_classes = config['num_classes'] + 1
        self.weight_decay = config['weight_decay']
        self.prob = 1. - config['keep_prob']
        self.data_format = config['data_format']
        self.mode = config['mode']
        self.batch_size = config['batch_size'] if config['mode'] == 'train' else 1
        self.nms_score_threshold = config['nms_score_threshold']
        self.nms_max_boxes = config['nms_max_boxes']
        self.nms_iou_threshold = config['nms_iou_threshold']
        self.reader = wrap.NewCheckpointReader(config['pretraining_weight'])

        if self.mode == 'train':
            self.num_train = data_provider['num_train']
            self.num_val = data_provider['num_val']
            self.train_generator = data_provider['train_generator']
            self.train_initializer, self.train_iterator = self.train_generator
            if data_provider['val_generator'] is not None:
                self.val_generator = data_provider['val_generator']
                self.val_initializer, self.val_iterator = self.val_generator

        self.global_step = tf.get_variable(name='global_step', initializer=tf.constant(0), trainable=False)
        self.is_training = True

        self._define_inputs()
        self._build_graph()
        self._create_saver()
        if self.mode == 'train':
            self._create_summary()
        self._init_session() 
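The reader stored in self.reader above is typically consumed later, when the backbone is built. A sketch of that pattern under assumed names (the checkpoint prefix and the variable name are placeholders, not this project's actual identifiers): read a pretrained tensor and use it as the initial value of a graph variable.

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow as wrap

reader = wrap.NewCheckpointReader("vgg_16.ckpt")                  # placeholder pretrained checkpoint
pretrained = reader.get_tensor("vgg_16/conv1/conv1_1/weights")    # placeholder variable name
kernel = tf.get_variable("conv1_1/weights", initializer=pretrained)  # shape inferred from the array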
Example #8
Source File: npz_file_to_checkpoint.py    From will-people-like-your-image with GNU Lesser General Public License v3.0
def print_variables_from_stored_model(graph_path):
    """Prints the names of the tensors stored in a tensorflow model.
        graph_path: path to the stored model.
    """
    reader = pywrap_tensorflow.NewCheckpointReader(graph_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        print("tensor_name: ", key) 
Example #9
Source File: freeze_graph.py    From pynlp with MIT License
def freeze_graph(output_graph):
    '''
    :param output_graph: path where the frozen PB model is saved
    :return:
    '''
    # checkpoint = tf.train.get_checkpoint_state(model_folder)  # check whether the ckpt files in the directory are usable
    # input_checkpoint = checkpoint.model_checkpoint_path  # get the path of the ckpt file
    from tensorflow.python import pywrap_tensorflow

    reader = pywrap_tensorflow.NewCheckpointReader("F:\python_work\siamese-lstm-network\deep-siamese-text-similarity\\atec_runs\\1553238291\checkpoints\model-170000")
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        print("tensor_name: ", key)
    # Specify the names of the output nodes; they must exist in the original model
    output_node_names = "accuracy/temp_sim,output/distance"
    input_checkpoint = "F:\python_work\siamese-lstm-network\deep-siamese-text-similarity\\atec_runs\\1553238291\checkpoints\model-170000.meta"
    model_path = 'F:\python_work\siamese-lstm-network\deep-siamese-text-similarity\\atec_runs\\1553238291\checkpoints\model-170000'  # data path

    saver = tf.train.import_meta_graph(input_checkpoint, clear_devices=False)
    graph = tf.get_default_graph()  # get the default graph
    input_graph_def = graph.as_graph_def()  # return a serialized GraphDef representing the current graph
    with tf.Session() as sess:
        saver.restore(sess, model_path)  # restore the graph and load the weights
        output_graph_def = graph_util.convert_variables_to_constants(  # freeze the model: fix the variable values as constants
            sess=sess,
            input_graph_def=input_graph_def,  # equivalent to sess.graph_def
            output_node_names=output_node_names.split(","))  # multiple output nodes are separated by commas

        with tf.gfile.GFile(output_graph, "wb") as f:  # save the model
            f.write(output_graph_def.SerializeToString())  # serialize the output
        print("%d ops in the final graph." % len(output_graph_def.node))  # number of op nodes in the final graph 
Example #10
Source File: train.py    From GeetChinese_crack with MIT License
def get_variables_in_checkpoint_file(self, file_name):
        try:
            reader = pywrap_tensorflow.NewCheckpointReader(file_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            return var_to_shape_map
        except Exception as e:  # pylint: disable=broad-except
            print(str(e))
            if "corrupted compressed block contents" in str(e):
                print("It's likely that your checkpoint file has been compressed "
                      "with SNAPPY.") 
Example #11
Source File: RefineDet.py    From Object-Detection-API-Tensorflow with MIT License
def __init__(self, config, data_provider):
        assert config['mode'] in ['train', 'test']
        assert config['data_format'] in ['channels_first', 'channels_last']
        self.config = config
        self.data_provider = data_provider
        self.input_size = config['input_size']
        if config['data_format'] == 'channels_last':
            self.data_shape = [self.input_size, self.input_size, 3]
        else:
            self.data_shape = [3, self.input_size, self.input_size]
        self.num_classes = config['num_classes'] + 1
        self.weight_decay = config['weight_decay']
        self.prob = 1. - config['keep_prob']
        self.data_format = config['data_format']
        self.mode = config['mode']
        self.batch_size = config['batch_size'] if config['mode'] == 'train' else 1
        self.anchor_ratios = [0.5, 1.0, 2.0]
        self.num_anchors = len(self.anchor_ratios)
        self.nms_score_threshold = config['nms_score_threshold']
        self.nms_max_boxes = config['nms_max_boxes']
        self.nms_iou_threshold = config['nms_iou_threshold']
        self.reader = wrap.NewCheckpointReader(config['pretraining_weight'])

        if self.mode == 'train':
            self.num_train = data_provider['num_train']
            self.num_val = data_provider['num_val']
            self.train_generator = data_provider['train_generator']
            self.train_initializer, self.train_iterator = self.train_generator
            if data_provider['val_generator'] is not None:
                self.val_generator = data_provider['val_generator']
                self.val_initializer, self.val_iterator = self.val_generator

        self.global_step = tf.get_variable(name='global_step', initializer=tf.constant(0), trainable=False)
        self.is_training = True

        self._define_inputs()
        self._build_graph()
        self._create_saver()
        if self.mode == 'train':
            self._create_summary()
        self._init_session() 
Example #12
Source File: saver.py    From lighttrack with MIT License
def get_variables_in_checkpoint_file(file_name):
    try:
        reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        var_to_shape_map = reader.get_variable_to_shape_map()
        return var_to_shape_map
    except Exception as e:  # pylint: disable=broad-except
        print(str(e))
        if "corrupted compressed block contents" in str(e):
            print(
                "It's likely that your checkpoint file has been compressed "
                "with SNAPPY.") 
Example #13
Source File: SSD300.py    From Object-Detection-API-Tensorflow with MIT License
def __init__(self, config, data_provider):
        assert config['mode'] in ['train', 'test']
        assert config['data_format'] in ['channels_first', 'channels_last']
        self.config = config
        self.data_provider = data_provider
        self.input_size = 300
        if config['data_format'] == 'channels_last':
            self.data_shape = [300, 300, 3]
        else:
            self.data_shape = [3, 300, 300]
        self.num_classes = config['num_classes'] + 1
        self.weight_decay = config['weight_decay']
        self.prob = 1. - config['keep_prob']
        self.data_format = config['data_format']
        self.mode = config['mode']
        self.batch_size = config['batch_size'] if config['mode'] == 'train' else 1
        self.nms_score_threshold = config['nms_score_threshold']
        self.nms_max_boxes = config['nms_max_boxes']
        self.nms_iou_threshold = config['nms_iou_threshold']
        self.reader = wrap.NewCheckpointReader(config['pretraining_weight'])

        if self.mode == 'train':
            self.num_train = data_provider['num_train']
            self.num_val = data_provider['num_val']
            self.train_generator = data_provider['train_generator']
            self.train_initializer, self.train_iterator = self.train_generator
            if data_provider['val_generator'] is not None:
                self.val_generator = data_provider['val_generator']
                self.val_initializer, self.val_iterator = self.val_generator

        self.global_step = tf.get_variable(name='global_step', initializer=tf.constant(0), trainable=False)
        self.is_training = True

        self._define_inputs()
        self._build_graph()
        self._create_saver()
        if self.mode == 'train':
            self._create_summary()
        self._init_session() 
Example #14
Source File: saver.py    From PoseFix_RELEASE with MIT License
def get_variables_in_checkpoint_file(file_name):
    try:
        reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        var_to_shape_map = reader.get_variable_to_shape_map()
        return reader, var_to_shape_map
    except Exception as e:  # pylint: disable=broad-except
        print(str(e))
        if "corrupted compressed block contents" in str(e):
            print(
                "It's likely that your checkpoint file has been compressed "
                "with SNAPPY.") 
Example #15
Source File: checkpint_inspect.py    From SSD.TensorFlow with Apache License 2.0
def print_all_tensors_name(file_name):
    try:
        reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        var_to_shape_map = reader.get_variable_to_shape_map()
        for key in var_to_shape_map:
            print(key)
    except Exception as e:  # pylint: disable=broad-except
        print(str(e))
        if "corrupted compressed block contents" in str(e):
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.") 
Example #16
Source File: npz_file_to_checkpoint.py    From will-people-like-your-image with GNU Lesser General Public License v3.0
def check_adam(model):
    """Checks whether a unwanted variable of the adam optimizer is still contained in a model.
    """
    reader = pywrap_tensorflow.NewCheckpointReader(model)
    target_map = reader.get_variable_to_shape_map()

    for key in list(target_map):
        if 'Adam_1' in key:
            print(key + "contains Adam.")
    return 
Example #17
Source File: inspect_cp.py    From NJUNMT-tf with Apache License 2.0
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors,
                                     all_tensor_names):
  """Prints tensors in a checkpoint file.
  If no `tensor_name` is provided, prints the tensor names and shapes
  in the checkpoint file.
  If `tensor_name` is provided, prints the content of the tensor.
  Args:
    file_name: Name of the checkpoint file.
    tensor_name: Name of the tensor in the checkpoint file to print.
    all_tensors: Boolean indicating whether to print all tensors.
    all_tensor_names: Boolean indicating whether to print all tensor names.
  """
  try:
    reader = pywrap_tensorflow.NewCheckpointReader(file_name)
    if all_tensors or all_tensor_names:
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in sorted(var_to_shape_map):
        print("tensor_name: ", key)
        if all_tensors:
          print(reader.get_tensor(key))
    elif not tensor_name:
      print(reader.debug_string().decode("utf-8"))
    else:
      print("tensor_name: ", tensor_name)
      print(reader.get_tensor(tensor_name))
  except Exception as e:  # pylint: disable=broad-except
    print(str(e))
    if "corrupted compressed block contents" in str(e):
      print("It's likely that your checkpoint file has been compressed "
            "with SNAPPY.")
    if ("Data loss" in str(e) and
        (any([e in file_name for e in [".index", ".meta", ".data"]]))):
      proposed_file = ".".join(file_name.split(".")[0:-1])
      v2_file_error_template = """
It's likely that this is a V2 checkpoint and you need to provide the filename
*prefix*.  Try removing the '.' and extension.  Try:
inspect checkpoint --file_name = {}"""
      print(v2_file_error_template.format(proposed_file)) 
Example #18
Source File: inspect_checkpoint.py    From MobileNet with Apache License 2.0
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
    """Prints tensors in a checkpoint file.

    If no `tensor_name` is provided, prints the tensor names and shapes
    in the checkpoint file.

    If `tensor_name` is provided, prints the content of the tensor.

    Args:
        file_name: Name of the checkpoint file.
        tensor_name: Name of the tensor in the checkpoint file to print.
        all_tensors: Boolean indicating whether to print all tensors.
    """
    try:
        reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        if all_tensors:
            var_to_shape_map = reader.get_variable_to_shape_map()
            for key in var_to_shape_map:
                print("tensor_name: ", key)
                print(reader.get_tensor(key))
        elif not tensor_name:
            print(reader.debug_string().decode("utf-8"))
        else:
            print("tensor_name: ", tensor_name)
            print(reader.get_tensor(tensor_name))
    except Exception as e:  # pylint: disable=broad-except
        print(str(e))
        if "corrupted compressed block contents" in str(e):
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.") 
Example #19
Source File: base_network.py    From VAE-GAN with MIT License
def load_pretrained_model_weights(self, sess, cfg, network_name, only_bottom=True):
		config_file = get_config(cfg)
		asset_filepath = config_file['assets dir']
		ckpt_path = os.path.join(asset_filepath, config_file["trainer params"].get("checkpoint dir", "checkpoint"))
		ckpt_name = ''
		with open(os.path.join(ckpt_path, 'checkpoint'), 'r') as infile:
			for line in infile:
				if line.startswith('model_checkpoint_path'):
					ckpt_name = line[len("model_checkpoint_path: \""):-2]
		checkpoint_path = os.path.join(ckpt_path, ckpt_name)

		print("Load checkpoint : ", checkpoint_path)
		reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
		var_to_shape_map = reader.get_variable_to_shape_map()

		assign_list = []
		var_list = self.all_vars
		var_dict = {var.name.split(':')[0] : var for var in var_list}

		for key in var_to_shape_map:
			if key.startswith(network_name):
				if only_bottom and 'fc' in key:
					continue
				var_name = self.name + '/' + key[len(network_name)+1:]
				assign_list.append(tf.assign(var_dict[var_name], reader.get_tensor(key)))

		assign_op = tf.group(assign_list)
		sess.run(assign_op)
		return True 
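As an aside, not part of the original example: the manual parsing of the `checkpoint` state file above can usually be replaced by tf.train.latest_checkpoint, which reads the same file and returns the most recent checkpoint prefix. A minimal sketch with a placeholder directory:

import tensorflow as tf

ckpt_dir = "checkpoint"  # placeholder: the directory containing the 'checkpoint' state file
checkpoint_path = tf.train.latest_checkpoint(ckpt_dir)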
Example #20
Source File: tf_policy_network.py    From crossgap_il_rl with GNU General Public License v2.0
def prase_checkpoint_data(self, checkpoint_name):
        print("===== Prase data from %s =====" % checkpoint_name)
        reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_name)
        var_to_shape_map = reader.get_variable_to_shape_map()
        for key in var_to_shape_map:
            print("tensor_name: ", key)
            # print(reader.get_tensor(key)) 
Example #21
Source File: tf_policy_network.py    From crossgap_il_rl with GNU General Public License v2.0
def resort_para_form_checkpoint( prefix, graph, sess):
    # ckpt_name_vec = ["./tf_net/planning_net/tf_saver_252750.ckpt", "./tf_net/control_mlp_net_train/tf_saver_2318100.ckpt"]
    # ckpt_name_vec = ["./tf_net/planning_net/tf_saver_252750.ckpt", "./tf_net/control_mlp_net_train/tf_saver_1300000.ckpt"]
    ckpt_name_vec = ["./tf_net/planning_net/tf_saver_107840000.ckpt", "./tf_net/pid_net/tf_saver_109330000.ckpt"]
    print("=========")
    file = open("full_structure.txt","w")
    file.writelines(str(graph.get_operations()))
    # for ops in tf.Graph.get_all_collection_keys():
    # for ops in graph.get_operations():
    #     file.writelines(ops)
    #     print(ops)
    file.close()
    print("=========")
    with tf.name_scope("restore"):
        for ckpt_name in ckpt_name_vec:
            print("===== Restore data from %s =====" % ckpt_name)
            reader = pywrap_tensorflow.NewCheckpointReader(ckpt_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            for _key in var_to_shape_map:
                # print("tensor_name: ", key)
                # print(reader.get_tensor(key))
                # tensor = graph.get_tensor_by_name(key)
                key = prefix + _key+ ":0"
                try:
                    tensor = graph.get_tensor_by_name(key)
                    sess.run(tf.assign(tensor, reader.get_tensor(_key)))
                    # print(tensor)
                except Exception as e:
                    # print(key, " can not be restored, e= ",str(e))
                    pass 
Example #22
Source File: tf_rapid_trajectory.py    From crossgap_il_rl with GNU General Public License v2.0
def resort_para_form_checkpoint( prefix, graph, sess, ckpt_name ):
    from tensorflow.python import pywrap_tensorflow
    # ckpt_name_vec = ["./tf_net/planning_net/tf_saver_252750.ckpt", "./tf_net/control_mlp_net_train/tf_saver_2318100.ckpt"]
    # ckpt_name_vec = ["./tf_net/planning_net/tf_saver_252750.ckpt", "./tf_net/control_mlp_net_train/tf_saver_1300000.ckpt"]
    # ckpt_name_vec = ["./tf_net/planning_net/tf_saver_252750.ckpt", "./tf_net/control_mlp_net/save_net_mlp.ckpt"]
    ckpt_name_vec = [ckpt_name]
    print("=========")
    file = open("full_structure.txt","w")
    file.writelines(str(graph.get_operations()))
    # for ops in tf.Graph.get_all_collection_keys():
    # for ops in graph.get_operations():
    #     file.writelines(ops)
    #     print(ops)
    file.close()
    print("=========")
    with tf.name_scope("restore"):
        for ckpt_name in ckpt_name_vec:
            print("===== Restore data from %s =====" % ckpt_name)
            reader = pywrap_tensorflow.NewCheckpointReader(ckpt_name)
            var_to_shape_map = reader.get_variable_to_shape_map()
            for _key in var_to_shape_map:
                # print("tensor_name: ", key)
                # print(reader.get_tensor(key))
                # tensor = graph.get_tensor_by_name(key)
                key = prefix + _key+ ":0"
                # key = prefix + _key
                try:
                    tensor = graph.get_tensor_by_name(key)
                    sess.run(tf.assign(tensor, reader.get_tensor(_key)))
                    # print(tensor)
                except Exception as e:
                    print(key, " can not be restored, e= ",str(e))
                    pass 
Example #23
Source File: checkpint_inspect.py    From inference with Apache License 2.0
def print_all_tensors_name(file_name):
    try:
        reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        var_to_shape_map = reader.get_variable_to_shape_map()
        for key in var_to_shape_map:
            print(key)
    except Exception as e:  # pylint: disable=broad-except
        print(str(e))
        if "corrupted compressed block contents" in str(e):
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.") 
Example #24
Source File: convert_from_depre.py    From tf-faster-rcnn with MIT License
def get_variables_in_checkpoint_file(file_name):
  try:
    reader = pywrap_tensorflow.NewCheckpointReader(file_name)
    var_to_shape_map = reader.get_variable_to_shape_map()
    return var_to_shape_map 
  except Exception as e:  # pylint: disable=broad-except
    print(str(e))
    if "corrupted compressed block contents" in str(e):
      print("It's likely that your checkpoint file has been compressed "
            "with SNAPPY.") 
Example #25
Source File: train_val.py    From tf-faster-rcnn with MIT License
def get_variables_in_checkpoint_file(self, file_name):
    try:
      reader = pywrap_tensorflow.NewCheckpointReader(file_name)
      var_to_shape_map = reader.get_variable_to_shape_map()
      return var_to_shape_map 
    except Exception as e:  # pylint: disable=broad-except
      print(str(e))
      if "corrupted compressed block contents" in str(e):
        print("It's likely that your checkpoint file has been compressed "
              "with SNAPPY.") 
Example #26
Source File: average_model.py    From sequencing with MIT License
def average_ckpt(checkpoint_from_paths,
                 checkpoint_to_path):
    """Migrates the names of variables within a checkpoint.
    Args:
      checkpoint_from_path: Path to source checkpoint to be read in.
      checkpoint_to_path: Path to checkpoint to be written out.
    """
    with ops.Graph().as_default():
        new_variable_map = defaultdict(list)
        for checkpoint_from_path in checkpoint_from_paths:
            logging.info('Reading checkpoint_from_path %s' % checkpoint_from_path)
            reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_from_path)
            name_shape_map = reader.get_variable_to_shape_map()
            for var_name in name_shape_map:
                tensor = reader.get_tensor(var_name)
                new_variable_map[var_name].append(tensor)

        variable_map = {}
        for var_name in name_shape_map:
            tensor = reduce(lambda x, y: x + y, new_variable_map[var_name]) / len(new_variable_map[var_name])
            var = variables.Variable(tensor, name=var_name)
            variable_map[var_name] = var
      
        print(variable_map)
        saver = saver_lib.Saver(variable_map)
      
        with session.Session() as sess:
            sess.run(variables.global_variables_initializer())
            logging.info('Writing checkpoint_to_path %s' % checkpoint_to_path)
            saver.save(sess, checkpoint_to_path)
    
    logging.info('Summary:')
    logging.info('  Averaged %d variable(s).' % len(new_variable_map)) 
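A hypothetical call, with placeholder paths: the source checkpoints must share variable names and shapes, and the averaged result is written to the given output prefix.

average_ckpt(["model.ckpt-9000", "model.ckpt-9500", "model.ckpt-10000"],  # placeholder checkpoints to average
             "averaged/model.ckpt")                                       # placeholder output prefix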
Example #27
Source File: saver.py    From tf-cpn with MIT License
def get_variables_in_checkpoint_file(file_name):
    try:
        reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        var_to_shape_map = reader.get_variable_to_shape_map()
        return var_to_shape_map
    except Exception as e:  # pylint: disable=broad-except
        print(str(e))
        if "corrupted compressed block contents" in str(e):
            print(
                "It's likely that your checkpoint file has been compressed "
                "with SNAPPY.") 
Example #28
Source File: plugin.py    From deep_image_model with Apache License 2.0
def _get_reader_for_run(self, run):
    if run in self.readers:
      return self.readers[run]

    config = self._configs[run]
    reader = None
    if config.model_checkpoint_path:
      try:
        reader = NewCheckpointReader(config.model_checkpoint_path)
      except Exception:  # pylint: disable=broad-except
        logging.warning('Failed reading %s', config.model_checkpoint_path)
    self.readers[run] = reader
    return reader 
Example #29
Source File: inspect_checkpoint.py    From tf.fashionAI with Apache License 2.0
def print_all_tensors_name(file_name):
    try:
        reader = pywrap_tensorflow.NewCheckpointReader(file_name)
        var_to_shape_map = reader.get_variable_to_shape_map()
        for key in var_to_shape_map:
            print(key)
    except Exception as e:  # pylint: disable=broad-except
        print(str(e))
        if "corrupted compressed block contents" in str(e):
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.") 
Example #30
Source File: RefineDet.py    From RefineDet-tensorflow with MIT License
def __init__(self, config, data_provider):
        assert config['mode'] in ['train', 'test']
        assert config['data_format'] in ['channels_first', 'channels_last']
        self.config = config
        self.data_provider = data_provider
        self.input_size = config['input_size']
        if config['data_format'] == 'channels_last':
            self.data_shape = [self.input_size, self.input_size, 3]
        else:
            self.data_shape = [3, self.input_size, self.input_size]
        self.num_classes = config['num_classes'] + 1
        self.weight_decay = config['weight_decay']
        self.prob = 1. - config['keep_prob']
        self.data_format = config['data_format']
        self.mode = config['mode']
        self.batch_size = config['batch_size'] if config['mode'] == 'train' else 1
        self.anchor_ratios = [0.5, 1.0, 2.0]
        self.num_anchors = len(self.anchor_ratios)
        self.nms_score_threshold = config['nms_score_threshold']
        self.nms_max_boxes = config['nms_max_boxes']
        self.nms_iou_threshold = config['nms_iou_threshold']
        self.reader = wrap.NewCheckpointReader(config['pretraining_weight'])

        if self.mode == 'train':
            self.num_train = data_provider['num_train']
            self.num_val = data_provider['num_val']
            self.train_generator = data_provider['train_generator']
            self.train_initializer, self.train_iterator = self.train_generator
            if data_provider['val_generator'] is not None:
                self.val_generator = data_provider['val_generator']
                self.val_initializer, self.val_iterator = self.val_generator

        self.global_step = tf.get_variable(name='global_step', initializer=tf.constant(0), trainable=False)
        self.is_training = True

        self._define_inputs()
        self._build_graph()
        self._create_saver()
        if self.mode == 'train':
            self._create_summary()
        self._init_session()