恢复已保存的Tensorflow .pb模型的权重

时间:2017-07-27 23:31:01

标签: python tensorflow

我在这里看过许多关于恢复已保存的TF模型的帖子,但没有人可以回答我的问题。 Using TF 1.0.0

具体而言,我有兴趣查看inceptionv3文件.pb中公开提供的tensorboard模型的权重。我设法使用一小段Python代码恢复它,并可以访问from tensorflow.python.platform import gfile INCEPTION_LOG_DIR = '/tmp/inception_v3_log' if not os.path.exists(INCEPTION_LOG_DIR): os.makedirs(INCEPTION_LOG_DIR) with tf.Session() as sess: model_filename = './model/tensorflow_inception_v3_stripped_optimized_quantized.pb' with gfile.FastGFile(model_filename, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) _= tf.import_graph_def(graph_def,name='') writer = tf.train.SummaryWriter(INCEPTION_LOG_DIR, graph_def) writer=tf.summary.FileWriter(INCEPTION_LOG_DIR, graph_def) writer.close() 中的图形高级视图:

tensors= tf.import_graph_def(graph_def,name='')

但是,我无法访问任何图层'权重。

return_elements=
即使我添加了任意{{1}},

也会返回空。它有任何重量吗?如果是,那么适当的程序是什么?谢谢。

4 个答案:

答案 0 :(得分:4)

使用此代码打印张量值:

with tf.Session() as sess:
    print sess.run('your_tensor_name')

您可以使用此代码来检索张量名称:

    op = sess.graph.get_operations()
    for m in op : 
    print(m.values())

答案 1 :(得分:2)

恢复重量和打印它们之间存在差异。前一个表示想要从已保存的ckpt文件中导入重量值以进行再训练或推理,而后者可能用于检查。此外,.pb文件将模型参数编码为tf.constant()ops。因此,模型参数不会出现在tf.trainable_variables()中,因此您无法直接使用.pb恢复权重。根据你的问题,我认为你只是想看到'检查的重量。

我们首先从.pb文件加载图表。

import tensorflow as tf
from tensorflow.python.platform import gfile

GRAPH_PB_PATH = './model/tensorflow_inception_v3_stripped_optimized_quantized.pb' #path to your .pb file
with tf.Session(config=config) as sess:
  print("load graph")
  with gfile.FastGFile(GRAPH_PB_PATH,'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')
    graph_nodes=[n for n in graph_def.node]

现在,当您将图表冻结到.pb文件时,您的变量会转换为Const类型,而作为trainabe变量的权重也会在Const中存储为.pb文件。 graph_nodes包含图表中的所有节点。但我们对所有Const类型节点感兴趣。

wts = [n for n in graph_nodes if n.op=='Const']

wts的每个元素都是NodeDef类型。它有几个属性,如name,op等。可以按如下方式提取值 -

from tensorflow.python.framework import tensor_util

for n in wts:
    print "Name of the node - %s" % n.name
    print "Value - " 
    print tensor_util.MakeNdarray(n.attr['value'].tensor)

希望这能解决您的问题。

答案 2 :(得分:1)

您可以使用此代码获取张量的名称。

[tf.get_default_graph()中张量的tensor.name。as_graph_def()。node]

答案 3 :(得分:0)

仅是小型工具即可打印.pb模型权重:

sorted = [...showList].sort(...)