如何在Tensorflow中从冻结的模型(pb文件)中找到output_node_names?

时间:2019-01-11 19:49:54

标签: python tensorflow

我正在尝试将我的Frozen_model.pb转换为基于Tensorflow的SSD Mobilenet V2 COCO预训练模型的tensorflow JS兼容(.pb)文件。我被困在如何获取使用tensorflowjs_converter时需要的output_node_names参数。如何知道输出节点名称?

我试图使用下面的Python脚本来获取操作名称,但无法理解哪个是输出节点。

def load_graph(model_file):
  graph = tf.Graph()
  graph_def = tf.GraphDef()

  with open(model_file, "rb") as f:
    graph_def.ParseFromString(f.read())
  with graph.as_default():
    tf.import_graph_def(graph_def)

  return graph

graph = load_graph('frozen_model.pb')
ops = graph.get_operations()

1 个答案:

答案 0 :(得分:1)

首先,您可以如下检查graph_def中的所有节点:

for node in graph_def.node
    print(node.name)

或者,如果您想直观地查看图形并确定将哪个节点用作输出,则可以使用TensorBoard。有一个名为import_pb_to_tensorboard的工具。本质上,它是使用少量线将图形写入log_dir,您可以将其指向张量板。您可以简单地将这些行复制到您自己的脚本中,以实现相同的功能,而无需使用tensorflow回购构建。

第三,还有一个名为summarize_graph tool的工具:

bazel build tensorflow/tools/graph_transforms:summarize_graph
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph --in_graph=/path/to/your/graph.pb