我在这里看过许多关于恢复已保存的TF模型的帖子,但没有人可以回答我的问题。 Using TF 1.0.0
具体而言,我有兴趣查看inceptionv3
文件.pb
中公开提供的tensorboard
模型的权重。我设法使用一小段Python代码恢复它,并可以访问from tensorflow.python.platform import gfile
INCEPTION_LOG_DIR = '/tmp/inception_v3_log'
if not os.path.exists(INCEPTION_LOG_DIR):
os.makedirs(INCEPTION_LOG_DIR)
with tf.Session() as sess:
model_filename = './model/tensorflow_inception_v3_stripped_optimized_quantized.pb'
with gfile.FastGFile(model_filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_= tf.import_graph_def(graph_def,name='')
writer = tf.train.SummaryWriter(INCEPTION_LOG_DIR, graph_def)
writer=tf.summary.FileWriter(INCEPTION_LOG_DIR, graph_def)
writer.close()
中的图形高级视图:
tensors= tf.import_graph_def(graph_def,name='')
但是,我无法访问任何图层'权重。
return_elements=
即使我添加了任意{{1}},也会返回空。它有任何重量吗?如果是,那么适当的程序是什么?谢谢。
答案 0 :(得分:4)
使用此代码打印张量值:
with tf.Session() as sess:
print sess.run('your_tensor_name')
您可以使用此代码来检索张量名称:
op = sess.graph.get_operations()
for m in op :
print(m.values())
答案 1 :(得分:2)
恢复重量和打印它们之间存在差异。前一个表示想要从已保存的ckpt文件中导入重量值以进行再训练或推理,而后者可能用于检查。此外,.pb
文件将模型参数编码为tf.constant()ops。因此,模型参数不会出现在tf.trainable_variables()中,因此您无法直接使用.pb
来恢复权重。根据你的问题,我认为你只是想看到'检查的重量。
我们首先从.pb
文件加载图表。
import tensorflow as tf
from tensorflow.python.platform import gfile
GRAPH_PB_PATH = './model/tensorflow_inception_v3_stripped_optimized_quantized.pb' #path to your .pb file
with tf.Session(config=config) as sess:
print("load graph")
with gfile.FastGFile(GRAPH_PB_PATH,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
graph_nodes=[n for n in graph_def.node]
现在,当您将图表冻结到.pb
文件时,您的变量会转换为Const
类型,而作为trainabe变量的权重也会在Const
中存储为.pb
文件。 graph_nodes
包含图表中的所有节点。但我们对所有Const
类型节点感兴趣。
wts = [n for n in graph_nodes if n.op=='Const']
wts
的每个元素都是NodeDef类型。它有几个属性,如name,op等。可以按如下方式提取值 -
from tensorflow.python.framework import tensor_util
for n in wts:
print "Name of the node - %s" % n.name
print "Value - "
print tensor_util.MakeNdarray(n.attr['value'].tensor)
希望这能解决您的问题。
答案 2 :(得分:1)
您可以使用此代码获取张量的名称。
[tf.get_default_graph()中张量的tensor.name。as_graph_def()。node]
答案 3 :(得分:0)
仅是小型工具即可打印.pb模型权重:
sorted = [...showList].sort(...)