Exporting a trained TensorFlow model to C++

Date: 2018-03-14 16:28:28

Tags: python c++ tensorflow bazel

I am trying to export a trained TensorFlow model to C++ using freeze_graph.py. Specifically, I am trying to export the ssd_mobilenet_v1_coco_2017_11_17 model with the following command:

bazel build tensorflow/python/tools:freeze_graph && \
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=frozen_inference_graph.pb \
--input_checkpoint=model.ckpt \
--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax 

The terminal reports that the build succeeded, but then shows the following error:

Traceback (most recent call last):
  File "/home/my_username/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 350, in <module>
    app.run(main=main, argv=[sys.argv[0]] + unparsed)
  File "/home/my_username/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 124, in run
    _sys.exit(main(argv))
  File "/home/my_username/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 249, in main
    FLAGS.saved_model_tags)
  File "/home/my_username/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 227, in freeze_graph
    input_graph_def = _parse_input_graph_proto(input_graph, input_binary)
  File "/home/my_username/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 171, in _parse_input_graph_proto
    text_format.Merge(f.read(), input_graph_def)
  File "/home/my_username/.cache/bazel/_bazel_my_username/3572bc2aff1de1dd37356cf341944e54/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/python/tools/freeze_graph.runfiles/protobuf_archive/python/google/protobuf/text_format.py", line 525, in Merge
    descriptor_pool=descriptor_pool)
  File "/home/my_username/.cache/bazel/_bazel_my_username/3572bc2aff1de1dd37356cf341944e54/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/python/tools/freeze_graph.runfiles/protobuf_archive/python/google/protobuf/text_format.py", line 579, in MergeLines
    return parser.MergeLines(lines, message)
  File "/home/my_username/.cache/bazel/_bazel_my_username/3572bc2aff1de1dd37356cf341944e54/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/python/tools/freeze_graph.runfiles/protobuf_archive/python/google/protobuf/text_format.py", line 612, in MergeLines
    self._ParseOrMerge(lines, message)
  File "/home/my_username/.cache/bazel/_bazel_my_username/3572bc2aff1de1dd37356cf341944e54/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/python/tools/freeze_graph.runfiles/protobuf_archive/python/google/protobuf/text_format.py", line 627, in _ParseOrMerge
    self._MergeField(tokenizer, message)
  File "/home/my_username/.cache/bazel/_bazel_my_username/3572bc2aff1de1dd37356cf341944e54/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/python/tools/freeze_graph.runfiles/protobuf_archive/python/google/protobuf/text_format.py", line 671, in _MergeField
    name = tokenizer.ConsumeIdentifierOrNumber()
  File "/home/my_username/.cache/bazel/_bazel_my_username/3572bc2aff1de1dd37356cf341944e54/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/python/tools/freeze_graph.runfiles/protobuf_archive/python/google/protobuf/text_format.py", line 1144, in ConsumeIdentifierOrNumber
    raise self.ParseError('Expected identifier or number, got %s.' % result)
google.protobuf.text_format.ParseError: 2:1 : Expected identifier or number, got `.

When I ran the command again, I got this output:

WARNING: /home/my_username/tensorflow/tensorflow/core/BUILD:1814:1: in includes attribute of cc_library rule //tensorflow/core:framework_headers_lib: '../../external/nsync/public' resolves to 'external/nsync/public' not below the relative path of its package 'tensorflow/core'. This will be an error in the future. Since this rule was created by the macro 'cc_header_only_library', the error might have been caused by the macro implementation in /home/my_username/tensorflow/tensorflow/tensorflow.bzl:1138:30
WARNING: /home/my_username/tensorflow/tensorflow/core/BUILD:1814:1: in includes attribute of cc_library rule //tensorflow/core:framework_headers_lib: '../../external/nsync/public' resolves to 'external/nsync/public' not below the relative path of its package 'tensorflow/core'. This will be an error in the future. Since this rule was created by the macro 'cc_header_only_library', the error might have been caused by the macro implementation in /home/my_username/tensorflow/tensorflow/tensorflow.bzl:1138:30
WARNING: /home/my_username/tensorflow/tensorflow/contrib/learn/BUILD:15:1: in py_library rule //tensorflow/contrib/learn:learn: target '//tensorflow/contrib/learn:learn' depends on deprecated target '//tensorflow/contrib/session_bundle:exporter': No longer supported. Switch to SavedModel immediately.
WARNING: /home/my_username/tensorflow/tensorflow/contrib/learn/BUILD:15:1: in py_library rule //tensorflow/contrib/learn:learn: target '//tensorflow/contrib/learn:learn' depends on deprecated target '//tensorflow/contrib/session_bundle:gc': No longer supported. Switch to SavedModel immediately.
INFO: Analysed target //tensorflow/python/tools:freeze_graph (0 packages loaded).
INFO: Found 1 target...
Target //tensorflow/python/tools:freeze_graph up-to-date:
  bazel-bin/tensorflow/python/tools/freeze_graph
INFO: Elapsed time: 0.419s, Critical Path: 0.00s
INFO: Build completed successfully, 1 total action
Traceback (most recent call last):
  File "/home/my_username/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 350, in <module>
    app.run(main=main, argv=[sys.argv[0]] + unparsed)
  File "/home/my_username/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 124, in run
    _sys.exit(main(argv))
  File "/home/my_username/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 249, in main
    FLAGS.saved_model_tags)
  File "/home/my_username/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 227, in freeze_graph
    input_graph_def = _parse_input_graph_proto(input_graph, input_binary)
  File "/home/my_username/tensorflow/bazel-bin/tensorflow/python/tools/freeze_graph.runfiles/org_tensorflow/tensorflow/python/tools/freeze_graph.py", line 171, in _parse_input_graph_proto
    text_format.Merge(f.read(), input_graph_def)
  File "/home/my_username/.cache/bazel/_bazel_my_username/3572bc2aff1de1dd37356cf341944e54/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/python/tools/freeze_graph.runfiles/protobuf_archive/python/google/protobuf/text_format.py", line 525, in Merge
    descriptor_pool=descriptor_pool)
  File "/home/my_username/.cache/bazel/_bazel_my_username/3572bc2aff1de1dd37356cf341944e54/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/python/tools/freeze_graph.runfiles/protobuf_archive/python/google/protobuf/text_format.py", line 579, in MergeLines
    return parser.MergeLines(lines, message)
  File "/home/my_username/.cache/bazel/_bazel_my_username/3572bc2aff1de1dd37356cf341944e54/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/python/tools/freeze_graph.runfiles/protobuf_archive/python/google/protobuf/text_format.py", line 612, in MergeLines
    self._ParseOrMerge(lines, message)
  File "/home/my_username/.cache/bazel/_bazel_my_username/3572bc2aff1de1dd37356cf341944e54/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/python/tools/freeze_graph.runfiles/protobuf_archive/python/google/protobuf/text_format.py", line 627, in _ParseOrMerge
    self._MergeField(tokenizer, message)
  File "/home/my_username/.cache/bazel/_bazel_my_username/3572bc2aff1de1dd37356cf341944e54/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/python/tools/freeze_graph.runfiles/protobuf_archive/python/google/protobuf/text_format.py", line 671, in _MergeField
    name = tokenizer.ConsumeIdentifierOrNumber()
  File "/home/my_username/.cache/bazel/_bazel_my_username/3572bc2aff1de1dd37356cf341944e54/execroot/org_tensorflow/bazel-out/k8-opt/bin/tensorflow/python/tools/freeze_graph.runfiles/protobuf_archive/python/google/protobuf/text_format.py", line 1144, in ConsumeIdentifierOrNumber
    raise self.ParseError('Expected identifier or number, got %s.' % result)
google.protobuf.text_format.ParseError: 2:1 : Expected identifier or number, got `.

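I am exporting ssd_mobilenet_v1_coco_2017_11_17 only as an exercise. I intend to export a model I trained myself and test its output with a program. I built TensorFlow 1.5 with Bazel v0.11.1, and I verified the installation with the following snippet from the TensorFlow website: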

# Python
import tensorflow as tf
hello = tf.constant('Hello, TensorFlow!')
sess = tf.Session()
print(sess.run(hello))

I also ran the object detection IPython example notebook, and it worked fine.

I am using Ubuntu 17.10.1 on a laptop with an Intel Core i5-8250U CPU, 8 GB RAM, a 1 TB HDD and an NVIDIA MX150 (2 GB) GPU. Please help. How do I export a trained model to C++?
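
A possible reading of the traceback above: the failing call is `text_format.Merge`, so freeze_graph was parsing `frozen_inference_graph.pb` as a text-format protobuf, and the parse error on the second line suggests the file is actually binary. The sketch below is a rough heuristic for checking that before running the tool. The helper name `looks_like_text_proto` and the 1024-byte sample size are assumptions made for illustration; they are not part of TensorFlow or protobuf.

```python
import string

# Byte values that can appear in a text-format protobuf:
# printable ASCII plus whitespace.
_TEXT_BYTES = frozenset(string.printable.encode("ascii"))


def looks_like_text_proto(path, sample_size=1024):
    """Guess whether a file holds a text-format protobuf (hypothetical helper).

    Returns True when the first `sample_size` bytes are all printable ASCII.
    This is only a heuristic; it does not parse the file.
    """
    with open(path, "rb") as f:
        sample = f.read(sample_size)
    # An empty file is treated as "not text" so the caller does not trust it.
    return bool(sample) and all(b in _TEXT_BYTES for b in sample)
```

If the check returns False for the input graph, passing `--input_binary=true` to freeze_graph makes it read the file as a binary GraphDef. Note also that a file named `frozen_inference_graph.pb` is, as the name suggests, probably already frozen, in which case freezing it again may not be needed at all.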

1 Answer:

Answer 0 (score: 0)

To export an object detection model, I used the export_inference_graph.py script in research/object_detection. Here is an example of running it:

python3 export_inference_graph.py \
--input_type image_tensor \
--pipeline_config_path path/to/model.config \
--trained_checkpoint_prefix path/to/model.ckpt-CHECKPOINTNUMBER \
--output_directory path/to/frozen_inference_graph

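I then used the resulting frozen_inference_graph.pb with C++ code that is essentially the same as label_image, with a few modifications to run a detection model instead of a classification model.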
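
The answer does not show its modified code, so the following is only an illustration of the part that differs most from classification: handling the outputs. Graphs exported with `--input_type image_tensor` typically take an `image_tensor` input and produce `detection_boxes`, `detection_scores`, `detection_classes` and `num_detections`, rather than a single `softmax` node. Boxes are `[ymin, xmin, ymax, xmax]`, normalized to the range 0 to 1. The sketch is written in Python for brevity, and the same logic should carry over to C++. The function name `decode_detections` and the `min_score` threshold are hypothetical.

```python
def decode_detections(boxes, scores, classes, num_detections,
                      image_width, image_height, min_score=0.5):
    """Convert raw detection outputs for one image into a list of dicts.

    boxes:   sequence of [ymin, xmin, ymax, xmax], normalized to [0, 1]
    scores:  sequence of confidence scores
    classes: sequence of class ids (floats in the raw output)
    """
    results = []
    # Only the first `num_detections` entries are valid; the rest is padding.
    for i in range(int(num_detections)):
        if scores[i] < min_score:
            continue
        ymin, xmin, ymax, xmax = boxes[i]
        results.append({
            "class_id": int(classes[i]),
            "score": float(scores[i]),
            # Pixel coordinates as (left, top, right, bottom).
            "box_px": (
                int(round(xmin * image_width)),
                int(round(ymin * image_height)),
                int(round(xmax * image_width)),
                int(round(ymax * image_height)),
            ),
        })
    return results
```

In a label_image-style program, the equivalent change would be to request these four output tensors in the session run call and loop over them as above, instead of taking the top-k entries of one score vector.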