import datetime
import random
import sys
from pathlib import Path

import numpy
import tensorflow as tf

print(sys.argv)

# Model export directory and test data file. Either can be overridden on the
# command line as `script.py [MODEL_DIR [TEST_FILE]]`; with no arguments the
# original hard-coded paths are used.
DEFAULT_VERSION = '/Users/dawang/Desktop/tfos_test/models/1575611737'
DEFAULT_TESTFILE = '/Users/dawang/Desktop/tfos_test/test/part-00002'
version = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_VERSION
testfile = sys.argv[2] if len(sys.argv) > 2 else DEFAULT_TESTFILE
print('version', version)
print('testfile', testfile)
pb_file_path = version

# Input parameters: feature indices outside [0, input_dim) are dropped.
input_dim = 500

# Print one progress line after every this many scored examples.
PROGRESS_EVERY = 10000


def build_example(line):
    """Parse one "label idx:val idx:val ..." line into sparse-tensor feed arrays.

    Only the index part of each ``idx:val`` item is used; every kept feature
    gets the value 1. Indices outside ``[0, input_dim)`` are skipped.

    Args:
        line: one non-blank line of the test file; tokens are separated by
            any whitespace.

    Returns:
        A tuple ``(indices, values, shape, label)``: ``indices`` is an array of
        ``[0, idx]`` pairs (empty when no feature survives the filter),
        ``values`` is an array of ones with one entry per index, ``shape`` is
        ``[1, input_dim]``, and ``label`` is the int label clamped so anything
        greater than 1 becomes 1.

    Raises:
        ValueError / IndexError: if the label or an index is not an integer
            or the line is empty.
    """
    # split() (not split(' ')) tolerates repeated spaces, tabs, trailing
    # whitespace and '\r', which previously produced '' tokens and crashed.
    parts = line.split()
    label = int(parts[0])
    if label > 1:
        label = 1
    indice_list = []
    for item in parts[1:]:
        index = int(item.split(':')[0])
        # Negative indices would also be invalid sparse indices; drop them
        # just like indices beyond the model's input dimension.
        if not 0 <= index < input_dim:
            continue
        indice_list.append([0, index])
    value_list = [1] * len(indice_list)
    shape_list = [1, input_dim]
    return (numpy.asarray(indice_list), numpy.asarray(value_list),
            numpy.asarray(shape_list), label)


# All work must happen inside this `with` block; otherwise the loaded graph
# does not carry its variables and parameters.
with tf.Session(graph=tf.Graph()) as sess:
    meta_graph_def = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], pb_file_path)
    signature = meta_graph_def.signature_def
    signature_key = "predict"

    y_tensor_name = signature[signature_key].outputs["probabilities"].name
    y = sess.graph.get_tensor_by_name(y_tensor_name)
    print('y', y_tensor_name)

    indices = signature[signature_key].inputs['indices'].name
    indice_tensor = sess.graph.get_tensor_by_name(indices)
    print('indices', indices)

    values = signature[signature_key].inputs["values"].name
    value_tensor = sess.graph.get_tensor_by_name(values)
    print('values', values)

    shape = signature[signature_key].inputs["dense_shape"].name
    shape_tensor = sess.graph.get_tensor_by_name(shape)
    print('shape', shape)

    # Build the column-0 slice ONCE. Writing `y[:, 0]` inside the loop adds a
    # new slice op to the graph on every iteration, so the graph (and memory
    # use) grows with the number of test lines.
    y_col0 = y[:, 0]

    one_count = 0    # examples whose (clamped) label is 1
    zero_count = 0   # examples with any other label
    pone_count = 0   # examples whose column-0 probability is < 0.5
    count = 0        # examples actually scored
    with open('{}.result'.format(version), 'w') as w, open(testfile, 'r') as f:
        # Stream the test file line by line rather than readlines(), so large
        # files are not held in memory all at once.
        for line in f:
            line = line.strip()
            if not line:
                continue
            indice_list, value_list, shape_list, label = build_example(line)
            if len(indice_list) == 0:
                continue
            if label == 1:
                one_count += 1
            else:
                zero_count += 1
            # Run prediction for this single example.
            y_out = sess.run(y_col0, feed_dict={
                indice_tensor: indice_list,
                value_tensor: value_list,
                shape_tensor: shape_list})
            count += 1
            if count % PROGRESS_EVERY == 0:
                print('working....', count)
            # NOTE(review): presumably column 0 of "probabilities" is P(class 0),
            # so < 0.5 means "predicted one" — confirm against the model.
            if y_out[0] < 0.5:
                pone_count += 1
            w.write('{} {}\n'.format(label, y_out[0]))

    print('#one', one_count)
    print('#zero', zero_count)
    print('#pone', pone_count)

# TODO: compute test AUC (e.g. sklearn.metrics.roc_auc_score) from the
# "label score" pairs written to the .result file.