我已使用以下脚本从检查点创建了frozen_inference_graph.pb
:
import tensorflow as tf
import numpy as np
meta_path = './model.ckpt-13800.meta'
output_node_names = ['outputs/Softmax']
with tf.Session() as sess:
saver = tf.train.import_meta_graph(meta_path, clear_devices=True)
saver.restore(sess, './model.ckpt-13800')
frozen_graph_def = tf.graph_util.convert_variables_to_constants(sess,
sess.graph_def,
output_node_names)
with open('frozen_inference_graph.pb', 'wb') as f:
f.write(frozen_graph_def.SerializeToString())
现在,我想使用tf.data
在此模型上进行推断。我当前的结构如下:
def load_graph(frozen_graph_filename):
with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(
graph_def,
input_map=None,
return_elements=None,
name=None,
op_dict=None,
producer_op_list=None
)
return graph
graph = load_graph("frozen_inference_graph.pb")
x = graph.get_tensor_by_name('import/IteratorGetNext:0')
y = graph.get_tensor_by_name('import/outputs/Softmax:0')
# create a random numpy array
test_features = np.random.random(size=(10, 185, 140, 50, 1))
with tf.Session(graph=graph) as sess:
pred_y = sess.run(y, feed_dict={x: test_features} )
print(pred_y)
如何使用tf.data
API来使用frozen_inference_graph
?