Optimizing a TensorFlow graph for YOLOv3

Date: 2019-10-11 14:58:20

Tags: python-3.x tensorflow keras yolo

This is the code snippet I use to save an optimized, frozen YOLOv3 graph after converting YOLOv3 to a .h5 Keras model and generating a frozen .pb file.

from tensorflow.python.tools import freeze_graph
from tensorflow.python.tools import optimize_for_inference_lib
import tensorflow as tf
input_node_name = 'input_1_1:0'
output_node_name = 'conv_81_1/BiasAdd,conv_93_1/BiasAdd,conv_105_1/BiasAdd'

output_frozen_graph_name = './frozen_yolo.pb'
output_optimized_graph_name = './optimized_yolo.pb'

freeze_graph.freeze_graph(input_graph='./yolov3.pbtxt', input_saver='',
                          input_binary=False, input_checkpoint='./yolov3.ckpt', output_node_names=output_node_name,
                          restore_op_name='save/restore_all', filename_tensor_name='save/Const:0',
                          output_graph=output_frozen_graph_name, clear_devices=True, initializer_nodes='')


input_graph_def = tf.GraphDef()

with tf.gfile.Open(output_frozen_graph_name, 'rb') as f:
    data = f.read()
    input_graph_def.ParseFromString(data)

output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def,
        [input_node_name], 
        [output_node_name],
        tf.float32.as_datatype_enum)

f = tf.gfile.FastGFile(output_optimized_graph_name, 'wb')
f.write(output_graph_def.SerializeToString())

This is the error message I get:

INFO:tensorflow:Restoring parameters from ./yolov3.ckpt
INFO:tensorflow:Froze 366 variables.
INFO:tensorflow:Converted 366 variables to const ops.
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-18-650086644cea> in <module>()
     24         [input_node_name],
     25         [output_node_name],
---> 26         tf.float32.as_datatype_enum)
     27 
     28 f = tf.gfile.FastGFile(output_optimized_graph_name, 'wb')

1 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/tools/strip_unused_lib.py in strip_unused(input_graph_def, input_node_names, output_node_names, placeholder_type_enum)
     52     if ":" in name:
     53       raise ValueError("Name '%s' appears to refer to a Tensor, "
---> 54                        "not a Operation." % name)
     55 
     56   # Here we replace the nodes we're going to override as inputs with

ValueError: Name 'input_1_1:0' appears to refer to a Tensor, not a Operation.
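The check that raises this error (visible in the `strip_unused_lib.py` frame of the traceback) rejects any name that contains `:`. In TensorFlow 1.x a tensor name has the form `<op_name>:<output_index>`, so `input_1_1:0` denotes output 0 of the operation `input_1_1`, while `optimize_for_inference` expects operation names. A minimal plain-Python sketch of turning a tensor name into an operation name, assuming the input placeholder op is named `input_1_1` (the helper `tensor_to_op_name` is hypothetical, not a TensorFlow API):

```python
def tensor_to_op_name(name: str) -> str:
    """Return the operation name for a TF1-style tensor name.

    "input_1_1:0" -> "input_1_1". Names without a ":<index>" suffix are
    assumed to already be operation names and are returned unchanged.
    """
    op_name, sep, index = name.rpartition(":")
    # Only strip the suffix when it is a numeric output index.
    if sep and index.isdigit():
        return op_name
    return name


print(tensor_to_op_name("input_1_1:0"))        # input_1_1
print(tensor_to_op_name("conv_81_1/BiasAdd"))  # conv_81_1/BiasAdd
```

Under that assumption, passing `'input_1_1'` (without `:0`) as the input node name would get past this particular check.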

When I use model.outputs.name, I get the following output_node_names:

['conv_81_1/BiasAdd', 'conv_93_1/BiasAdd', 'conv_105_1/BiasAdd']

When I run model.inputs.name, I get the following input node name:

['input_1_1:0']
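The snippet also passes names in two different shapes: `freeze_graph` takes `output_node_names` as one comma-separated string, whereas `optimize_for_inference` takes Python lists of node names, so `[output_node_name]` is a one-element list whose only entry is the whole comma-joined string. A sketch of building both shapes from the names printed above (the helpers `strip_index` and `build_node_name_args` are hypothetical; they only manipulate strings and do not call TensorFlow):

```python
from typing import List, Tuple


def strip_index(name: str) -> str:
    # "input_1_1:0" -> "input_1_1"; names without ":<n>" are returned unchanged.
    op_name, sep, index = name.rpartition(":")
    return op_name if sep and index.isdigit() else name


def build_node_name_args(input_names: List[str],
                         output_names: List[str]) -> Tuple[List[str], List[str], str]:
    """Return (inputs, outputs, outputs_csv).

    inputs / outputs: lists of op names, the shape optimize_for_inference takes.
    outputs_csv: a single comma-separated string, the shape freeze_graph takes.
    """
    inputs = [strip_index(n) for n in input_names]
    outputs = [strip_index(n) for n in output_names]
    return inputs, outputs, ",".join(outputs)


inputs, outputs, outputs_csv = build_node_name_args(
    ["input_1_1:0"],
    ["conv_81_1/BiasAdd", "conv_93_1/BiasAdd", "conv_105_1/BiasAdd"],
)
print(inputs)       # ['input_1_1']
print(outputs)      # ['conv_81_1/BiasAdd', 'conv_93_1/BiasAdd', 'conv_105_1/BiasAdd']
print(outputs_csv)  # conv_81_1/BiasAdd,conv_93_1/BiasAdd,conv_105_1/BiasAdd
```

With these values, `outputs_csv` would be the `output_node_names` argument of `freeze_graph`, and the two lists would be passed as-is, without wrapping them in another list, as `optimize_for_inference(input_graph_def, inputs, outputs, tf.float32.as_datatype_enum)`.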

How can I fix this?

0 answers:

No answers