如何计算使用从.pb文件加载的图形定义的张量流模型中可训练参数的总数?

时间:2018-05-03 15:51:43

标签: tensorflow neural-network convolutional-neural-network

我想计算张量流模型中的参数。它类似于现有问题如下。

How to count total number of trainable parameters in a tensorflow model?

但是如果使用从.pb文件加载的图形定义模型,则所有建议的答案都不起作用。基本上我用以下函数加载了图形。

def load_graph(model_file):

  graph = tf.Graph()
  graph_def = tf.GraphDef()

  with open(model_file, "rb") as f:
    graph_def.ParseFromString(f.read())

  with graph.as_default():
    tf.import_graph_def(graph_def)

  return graph

一个例子是在tensorflow-for-poets-2中加载frozen_graph.pb文件以进行重新训练。

https://github.com/googlecodelabs/tensorflow-for-poets-2

1 个答案:

答案 0 :(得分:0)

据我了解,GraphDef没有足够的信息来描述Variables。正如here所述,您需要MetaGraph,其中包含GraphDefCollectionDef,这是一张可以描述Variables的地图。因此,以下代码应该为我们提供正确的可训练变量计数。

导出MetaGraph:

import tensorflow as tf

a = tf.get_variable('a', shape=[1])
b = tf.get_variable('b', shape=[1], trainable=False)
init = tf.global_variables_initializer()
saver = tf.train.Saver([a])

with tf.Session() as sess:
    sess.run(init)
    saver.save(sess, r'.\test')

导入MetaGraph并计算可训练参数的总数。

import tensorflow as tf

saver = tf.train.import_meta_graph('test.meta')

with tf.Session() as sess:
    saver.restore(sess, 'test')

total_parameters = 0
for variable in tf.trainable_variables():
    total_parameters += 1
print(total_parameters)