我有一个保存在pb文件中的模型。我希望能计算出它的触发器。我的示例代码如下:
import tensorflow as tf
import sys
from tensorflow.python.platform import gfile
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat
pb_file = 'themodel.pb'
run_meta = tf.RunMetadata()
with tf.Session() as sess:
print("load graph")
with gfile.FastGFile(pb_path,'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
flops = tf.profiler.profile(tf.get_default_graph(), run_meta=run_meta,
options=tf.profiler.ProfileOptionBuilder.float_operation())
print("test flops:{:,}".format(flops.total_float_ops))
打印信息很奇怪。我的模型有几十层,但是在打印的信息中只报告了18个触发器。我非常确定该模型已正确加载,因为如果尝试按如下所示打印每层的名称,则:
print([n.name for n in tf.get_default_graph().as_graph_def().node])
打印信息显示正确的网络。
我的代码怎么了?
谢谢!
答案 0 :(得分:1)
我想我找到了问题的原因和解决方案。以下代码可以打印给定pb文件的触发器。
import os
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import importer
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
pb_path = 'mymodel.pb'
run_meta = tf.RunMetadata()
with tf.Graph().as_default():
output_graph_def = graph_pb2.GraphDef()
with open(pb_path, "rb") as f:
output_graph_def.ParseFromString(f.read())
_ = importer.import_graph_def(output_graph_def, name="")
print('model loaded!')
all_keys = sorted([n.name for n in tf.get_default_graph().as_graph_def().node])
# for k in all_keys:
# print(k)
with tf.Session() as sess:
flops = tf.profiler.profile(tf.get_default_graph(), run_meta=run_meta,
options=tf.profiler.ProfileOptionBuilder.float_operation())
print("test flops:{:,}".format(flops.total_float_ops))
问题中打印的触发器只有18个的原因是,在生成pb文件时,我将输入图像的形状设置为[None, None, 3]
。如果我将其更改为[500, 500, 3]
,则印刷的拖鞋将是正确的。
答案 1 :(得分:0)
不确定在不知道输入和输出的情况下如何计算性能指标:也许它需要CallableOptions?我会使用trace_next_step
and a Session
而不是手动进行计算。