嗨:现在我正在努力将张量流检查点模型转换为caffe模型。我已成功读取图形并已提取每个节点中的attr值。我得到了' dilations',' strides'和'填充' attr in" Conv2D"节点和'#34;权重"中的形状节点,但我无法获得“塑造”的价值。 attr,它在Conv2D的输入节点中是空的。但是,这些形状显示在张量板图中。 这是我的代码:
new_saver = tf.train.import_meta_graph(meta_path)
new_saver.restore(sess, tf.train.latest_checkpoint(ckpt_path))
graph_def = sess.graph_def
node_list = graph_def.node
# conv_node, weight_node, from_node are all in node_list
# conv_node: the conv2d node in graph_def
# weight_node: the weights node of conv2d
# from_node: the input feature map node of conv2d
weight_shape_attr = weight_node.attr['shape']
weight_shapes = [dim.size for dim in weight_shape_attr.shape.dim]
strides = [ii for ii in conv_node.attr['strides'].list.i]
dilations = [ii for ii in conv_node.attr['dilations'].list.i]
shapes = from_node.attr['shape'] # this is empty
和张量板图: tensorboard_graph
请注意,Conv2D节点的输入形状为?x79x79x32,它必须存储在模型文件中的某个位置。任何人都可以提供帮助吗?任何点击都会有所帮助,谢谢。
答案 0 :(得分:4)
Tensorflow图表的as_graph_def
方法具有可选参数add_shapes
(默认情况下为False
)。如果设置为True
,则会产生节点的其他属性:_output_shapes
。
所以你可以尝试以这种方式获取GraphDef:
graph_def = sess.graph.as_graph_def(add_shapes=True)