如何在张量流图中得到每个节点的输入形状?

时间:2018-03-22 03:16:29

标签: python tensorflow caffe

嗨:现在我正在努力将张量流检查点模型转换为caffe模型。我已成功读取图形并已提取每个节点中的attr值。我得到了' dilations',' strides'和'填充' attr in" Conv2D"节点和'#34;权重"中的形状节点,但我无法获得“塑造”的价值。 attr,它在Conv2D的输入节点中是空的。但是,这些形状显示在张量板图中。 这是我的代码:

new_saver = tf.train.import_meta_graph(meta_path)          
new_saver.restore(sess, tf.train.latest_checkpoint(ckpt_path))
graph_def = sess.graph_def
node_list = graph_def.node

# conv_node, weight_node, from_node are all in node_list
# conv_node: the conv2d node in graph_def
# weight_node: the weights node of conv2d
# from_node: the input feature map node of conv2d

weight_shape_attr = weight_node.attr['shape']
weight_shapes = [dim.size for dim in weight_shape_attr.shape.dim]

strides = [ii for ii in conv_node.attr['strides'].list.i]
dilations = [ii for ii in conv_node.attr['dilations'].list.i]

shapes = from_node.attr['shape']  # this is empty

和张量板图: tensorboard_graph

请注意,Conv2D节点的输入形状为?x79x79x32,它必须存储在模型文件中的某个位置。任何人都可以提供帮助吗?任何点击都会有所帮助,谢谢。

1 个答案:

答案 0 :(得分:4)

Tensorflow图表的as_graph_def方法具有可选参数add_shapes(默认情况下为False)。如果设置为True,则会产生节点的其他属性:_output_shapes

所以你可以尝试以这种方式获取GraphDef:

graph_def = sess.graph.as_graph_def(add_shapes=True)