当尝试将Squeezenet1.1加载到TensorFlow中时,出现以下错误消息:
import tensorflow as tf
import os
import numpy as np
from tensorflow.core.framework import graph_pb2
graph_def = graph_pb2.GraphDef()
with open(os.path.join(script_dir, 'squeezenet.pb'), "rb") as f:
graph_def.ParseFromString(f.read())
test_graph = tf.Graph()
with test_graph.as_default() as graph:
tf.import_graph_def(graph_def)
with tf.Session(graph=graph) as sess:
data = graph.get_tensor_by_name("import/data:0")
data_op = graph.get_operation_by_name("import/data")
random_input = np.random.rand(1, 3, 224, 224).astype(np.float)
sess.run(data_op, feed_dict={data: random_input})
我的目标是加载以ONNX文件(https://github.com/onnx/models/tree/master/models/image_classification/squeezenet)形式给出的Squeezenet,然后首先将其另存为.pb文件。给定的.pb文件,我使用以下代码执行一个推断:
tf.contrib.framework.get_placeholders(graph)
有趣的是,尽管January 15, 2018 - January 18, 2018 - Nothing needs to happen
January 28, 2018 - February 2, 2018 would need to be split into two
明确指出,只有一个占位符,“导入/数据”操作仍需要feed_dict,而Tensorboard也显示此操作需要该占位符。
感谢您的帮助!