我知道如何加载已保存的TensorFlow模型,但我将如何知道输入和输出张量名称。
我可以使用tf.import_graph_def加载protobuf文件,然后使用get_tensor_by_name函数加载张量,但是我怎么知道任何预训练模型的张量名称。我需要检查他们的文档还是其他方法?
答案 0 :(得分:1)
假设输入和输出张量是占位符,则类似以下内容将对您有所帮助:
X = np.ones((1,3), dtype=np.float32)
tf.reset_default_graph()
model_saver = tf.train.Saver(defer_build=True)
input_pl = tf.placeholder(tf.float32, shape=[1,3], name="Input")
w = tf.Variable(tf.random_normal([3,3], stddev=0.01), name="Weight")
b = tf.Variable(tf.zeros([3]), name="Bias")
output = tf.add(tf.matmul(input_pl, w), b)
model_saver.build()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
model_saver.save(sess, "./model.ckpt")
现在,该图已构建并保存,我们可以看到如下所示的占位符名称:
model_loader = tf.train.Saver()
sess = tf.Session()
model_loader.restore(sess, "./model.ckpt")
placeholders = [x for x in tf.get_default_graph().get_operations() if x.type == "Placeholder"]
# [<tf.Operation 'Input' type=Placeholder>]
答案 1 :(得分:0)
仅输入解决方案:
# read pb into graph_def
with tf.gfile.GFile(input_model_filepath, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# import graph_def
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def)
# print operations
for op in graph.get_operations():
if op.type == "Placeholder":
print(op.name)
答案 2 :(得分:0)
您可以检查图中每个操作的名称和输入列表,以找到张量的名称。
with tf.gfile.GFile(input_model_filepath, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def)
for op in graph.get_operations():
print(op.name, [inp for inp in op.inputs])