我想使用MMdnn将tensorflow ResNet模型转换为其他框架。看来我只能使用mmconvert从.pb冻结的图形文件中读取。
但是,使用tf.estimator.Estimator时,它创建的.pb文件是SavedModelDef。我知道这是tf GraphDef的包装。因此,可以使用freeze_graph.py从SavedModel中提取GraphDef .pb文件。
从那里,我需要tf GraphDef中输入节点的名称。但是我不确定如何通过查看.pbtxt来识别名称。根据框架,tf.Estimator输入带有tf.Dataset对象。
我猜测应该在接受输入的地方有一个tf.Placeholder。但是我不确定如何找到输入节点的实际位置。
答案 0 :(得分:1)
在这里回答我自己的问题。 tensorflow附带的freeze_graph实用程序可用于从tf SavedModel格式提取graphdef。
要查找输入节点的名称,请确保将tf SavedModel保存为pbtxt格式。打开它并查找您的计算图的第一个节点,例如如果使用tf resnet,则第一个节点将命名为resnet_model / *。查找提供该节点的节点,您将拥有输入节点的名称,以指定给MMdnn工具。我希望这是Estimator为输入添加的tf.Placeholder。该节点仅被命名为Placeholder
,所以这就是我指定的输入节点。
首先提取计算图。
freeze_graph --input_saved_model_dir <path/to/saved_model_dir> --output_node_names softmax --output_graph ./graph_def.pb
然后使用MMdnn将其转换为caffe。
mmconvert -sf tensorflow -iw ./graph_def.pb --inNodeName Placeholder --inputShape 224,224,3 --dstNodeName softmax -df caffe -om tf_resnet