如何从TensorFlow .pb模型中获取权重格式?

时间:2017-07-12 07:52:08

标签: c++ tensorflow model

我想重新组织tensorflow .pb模型的节点,所以我首先从GraphDef获取NodeDef,然后使用NodeDef.attr()获取attr用于“Conv2D”的节点。    我可以从attr获取strides,padding,data_format,use_cudnn_on_gpu等参数,但是无法获取权重格式参数。    我使用的语言是c ++。    怎么弄它!谢谢!

1 个答案:

答案 0 :(得分:4)

Conv2D有两个输入:第一个是数据,第二个是filter(或权重),因此您只需检查Conv2D的第二个输入的格式即可。如果您使用的是C ++,可以试试这个:

# Assuming inputs: conv2d_node, node_map.
filter_node_name = conv2d_node.input(1)
filter_node = node_map[filter_node_name]
# You might need to check identity node here.
# Get the shape of filter_node using NodeDef.attr()