冻结模型时出错(freeze_graph)

时间:2017-09-14 15:38:28

标签: tensorflow

我在tensorflow中非常新,并且希望在C ++环境中使用预先训练的模型(Python)进行推理。据我所知,为此,我需要使用" freeze_graph"冻结训练好的模型。工具。

这是一个代码片段,它如何查找非常简单的MNIST模型:

with tf.Session(config=config) as s:
    s.run(tf.global_variables_initializer())

    for i in range(n):
        batch = mnist.train.next_batch(50)
        train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

    print('test accuracy %g' % accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
    saver.save(s, 'models/saved_checkpoint')

with tf.Session(config=config) as s:
    # save the graph definition
    tf.train.write_graph(s.graph_def, 'models', "graph_def.pbtxt")

freeze_graph.freeze_graph(input_graph = "models/graph_def.pbtxt", input_saver = "", input_binary = False, input_checkpoint = "models/saved_checkpoint", output_node_names = "output_node", restore_op_name = "save/restore_all", filename_tensor_name = "save/Const:0", output_graph = "frozen_graph.pb", clear_devices = True, initializer_nodes = "")

这样做我收到以下错误:

  

文件" mnist.py",第180行,在main中       output_graph =" frozen_graph.pb",clear_devices = True,initializer_nodes ="")
  文件   " /usr/local/lib/python2.7/dist-packages/tensorflow/python/tools/freeze_graph.py" ;,   第184行,在freeze_graph中       variable_names_blacklist)
  文件" /usr/local/lib/python2.7/dist-packages/tensorflow/python/tools/freeze_graph.py",   第87行,在freeze_graph_with_def_protos中       _ = importer.import_graph_def(input_graph_def,name ="")
   文件" /usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/importer.py",   第313行,在import_graph_def中       op_def = op_def)File" /usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py",   第2633行,在create_op中       self._add_op(RET)
  文件" /usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py",   第2312行,在_add_op中       "已被使用" %op.name)ValueError:无法添加名称为conv1 / Variable / Adam的op,因为该名称已被使用

有人知道这里可能有什么问题吗?我使用tensorflow 1.3和python 2.7。不幸的是,我找不到有关图冻结的大量信息,可用的示例对我不起作用......

提前感谢任何建议!

最佳, 阿列克谢

1 个答案:

答案 0 :(得分:1)

我能够使用Tensorflow-GPU 1.3冻结图形。我在虚拟环境中安装了tensorflow,所以' freeze_graph.py'是虚拟环境路径。

用于冻结图表的命令:

python /home/ck/venvs/enet/lib/python2.7/site-packages/tensorflow/python/tools/freeze_graph.py --input_graph ./log/graph.pbtxt --input_checkpoint ./log/model.ckpt-0 --output_graph ./log/frozen_model.pb --output_node_names=ENet/logits_to_softmax

这里' log'是保存检查点以及graph.pbtxt的文件夹。

注意:一旦检查点和pbtxt文件被保存,我就从命令行执行了此操作。我还没有尝试过你描述的方法,但是如果你的目的只是为了冻结图形,那么我想这应该可行。