如何从Tensorflow中的.pb模型中获取权重

时间:2017-09-09 05:36:53

标签: python tensorflow

我训练了一个模型,然后通过冻结该模型创建一个.pb文件。 所以,我的问题是如何从.pb文件中获取权重,或者我必须为获取权重做更多的处理

@mrry,请指导我。

1 个答案:

答案 0 :(得分:12)

我们首先从.pb文件加载图表。

import tensorflow as tf
from tensorflow.python.platform import gfile

GRAPH_PB_PATH = './model/tensorflow_inception_v3_stripped_optimized_quantized.pb' #path to your .pb file
with tf.Session(config=config) as sess:
  print("load graph")
  with gfile.FastGFile(GRAPH_PB_PATH,'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')
    graph_nodes=[n for n in graph_def.node]

现在,当您将图表冻结到.pb文件时,您的变量会转换为Const类型,而作为trainabe变量的权重也会在Const中存储为.pb文件。 graph_nodes包含图表中的所有节点。但我们对所有Const类型节点感兴趣。

wts = [n for n in graph_nodes if n.op=='Const']

wts的每个元素都是NodeDef类型。它有几个属性,如name,op等。可以按如下方式提取值 -

from tensorflow.python.framework import tensor_util

for n in wts:
    print "Name of the node - %s" % n.name
    print "Value - " 
    print tensor_util.MakeNdarray(n.attr['value'].tensor)

希望这能解决您的问题。