如何知道已保存模型中的输出和输入张量名称

时间:2019-03-23 13:01:09

标签: tensorflow

我知道如何加载已保存的TensorFlow模型,但我将如何知道输入和输出张量名称。

我可以使用tf.import_graph_def加载protobuf文件,然后使用get_tensor_by_name函数加载张量,但是我怎么知道任何预训练模型的张量名称。我需要检查他们的文档还是其他方法?

3 个答案:

答案 0 :(得分:1)

假设输入和输出张量是占位符,则类似以下内容将对您有所帮助:

X = np.ones((1,3), dtype=np.float32)
tf.reset_default_graph()
model_saver = tf.train.Saver(defer_build=True)
input_pl = tf.placeholder(tf.float32, shape=[1,3], name="Input")
w = tf.Variable(tf.random_normal([3,3], stddev=0.01), name="Weight")
b = tf.Variable(tf.zeros([3]), name="Bias")
output = tf.add(tf.matmul(input_pl, w), b)
model_saver.build()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
model_saver.save(sess, "./model.ckpt")

现在,该图已构建并保存,我们可以看到如下所示的占位符名称:

model_loader = tf.train.Saver()
sess = tf.Session()
model_loader.restore(sess, "./model.ckpt")
placeholders = [x for x in tf.get_default_graph().get_operations() if x.type == "Placeholder"]
# [<tf.Operation 'Input' type=Placeholder>]

答案 1 :(得分:0)

仅输入解决方案:

# read pb into graph_def
with tf.gfile.GFile(input_model_filepath, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# import graph_def
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def)

# print operations
for op in graph.get_operations():
    if op.type == "Placeholder":
        print(op.name)

答案 2 :(得分:0)

您可以检查图中每个操作的名称和输入列表,以找到张量的名称。

with tf.gfile.GFile(input_model_filepath, "rb") as f:
  graph_def = tf.GraphDef()
  graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
  tf.import_graph_def(graph_def)

for op in graph.get_operations():
  print(op.name, [inp for inp in op.inputs])