How can I run inference on a frozen inference graph using the tf.data API?

Asked: 2019-07-31 21:29:23

Tags: python tensorflow

I created frozen_inference_graph.pb from a checkpoint using the following script:

import tensorflow as tf
import numpy as np

meta_path = './model.ckpt-13800.meta'
output_node_names = ['outputs/Softmax']
with tf.Session() as sess:
    saver = tf.train.import_meta_graph(meta_path, clear_devices=True)
    saver.restore(sess, './model.ckpt-13800')
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(sess,
                                                                    sess.graph_def, 
                                                                    output_node_names)
    with open('frozen_inference_graph.pb', 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())

Now I would like to run inference on this model using tf.data. My current setup looks like this:

def load_graph(frozen_graph_filename):
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        # name defaults to "import", so every node is prefixed with "import/".
        tf.import_graph_def(graph_def)
    return graph

graph = load_graph("frozen_inference_graph.pb")
x = graph.get_tensor_by_name('import/IteratorGetNext:0')
y = graph.get_tensor_by_name('import/outputs/Softmax:0')

# create a random numpy array
test_features = np.random.random(size=(10, 185, 140, 50, 1)) 
with tf.Session(graph=graph) as sess:
    pred_y = sess.run(y, feed_dict={x: test_features} )
    print(pred_y)

How can I use the tf.data API to run inference with frozen_inference_graph.pb?
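In TF 1.x, the usual way to drive a frozen graph from a tf.data pipeline is to remap its `IteratorGetNext` node with the `input_map` argument of `tf.import_graph_def`, instead of feeding that tensor through `feed_dict`. Below is a minimal self-contained sketch of that idea: it builds and freezes a tiny stand-in model (a 4-feature softmax layer, standing in for model.ckpt-13800 — the shapes and weights are hypothetical), then attaches a fresh tf.data pipeline to the frozen graph. It uses `tf.compat.v1` so it also runs under TF 2.x:

```python
import numpy as np
import tensorflow.compat.v1 as tf  # TF 1.x-style APIs; on TF 1.x, `import tensorflow as tf` works too

tf.disable_eager_execution()

# --- Build and freeze a tiny stand-in model (replace with your real checkpoint). ---
train_graph = tf.Graph()
with train_graph.as_default():
    ds = tf.data.Dataset.from_tensor_slices(np.zeros((4, 4), np.float32)).batch(2)
    x = tf.data.make_one_shot_iterator(ds).get_next()   # op name: IteratorGetNext
    w = tf.Variable(np.eye(4, dtype=np.float32), name="w")
    tf.nn.softmax(tf.matmul(x, w), name="outputs/Softmax")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ["outputs/Softmax"])

# --- Inference: feed a fresh tf.data pipeline into the frozen graph via input_map. ---
infer_graph = tf.Graph()
with infer_graph.as_default():
    test_features = np.random.random(size=(6, 4)).astype(np.float32)
    ds = tf.data.Dataset.from_tensor_slices(test_features).batch(2)
    next_batch = tf.data.make_one_shot_iterator(ds).get_next()

    # Rewire consumers of the frozen graph's IteratorGetNext:0 to our new iterator.
    y, = tf.import_graph_def(
        frozen_graph_def,
        input_map={"IteratorGetNext:0": next_batch},
        return_elements=["outputs/Softmax:0"],
        name="")  # empty name: imported nodes keep their original names

with tf.Session(graph=infer_graph) as sess:
    preds = []
    while True:  # pull batches until the dataset is exhausted
        try:
            preds.append(sess.run(y))
        except tf.errors.OutOfRangeError:
            break
preds = np.concatenate(preds, axis=0)
print(preds.shape)
```

With `input_map` in place, `sess.run(y)` pulls each batch directly from the new pipeline, so no `feed_dict` is needed; the frozen graph's original iterator nodes are imported but left unused.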

0 Answers:

No answers yet