在tensorflow对象检测API中,是否有办法知道对象检测模型有多少个参数?

时间:2019-07-01 12:12:29

标签: python tensorflow object-detection-api

我使用张量对象检测(TFOD)API训练了不同的模型,我想知道为给定模型训练了多少参数。

我运行的RCNN,SSD,RFCN速度更快,并且具有不同的图像分辨率,我想有一种方法来知道要训练多少个参数。有办法吗?

我尝试了How to count total number of trainable parameters in a tensorflow model?的答案,但没有运气。

这是我在model_main.py的第103行添加的代码:

print("Training {} parameters".format(np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]))

我认为问题是我没有访问TFOD正在运行的tf.Session(),因此我的代码始终返回0.0参数(尽管训练策略很好并且可以训练,希望有数百万个参数),但我没有不知道如何解决这个问题。

2 个答案:

答案 0 :(得分:0)

TFOD API使用tf.estimator.Estimator进行培训和评估。 Estimator对象提供了获取所有变量Estimator.get_variable_names()reference)的功能。

您可以在print(estimator.get_variable_names())here)之后添加estimator.train_and_evaluate()行。

培训完成后,您将看到所有打印的变量名。要更快地查看结果,您只需训练1步即可。

答案 1 :(得分:0)

使用export_inference_graph.py时,脚本还会分析您的模型,并计算参数和FLOPS(如果可能)。 如果看起来像这样:

angular