无法将tensorflow冻结模型转换为tensorRT

时间:2019-10-30 14:08:33

标签: python tensorflow darknet tensorrt

我正在研究对象检测问题,在该问题中,我使用darknet获得了经过训练的模型(.cfg和.weights文件)。为了将其转换为tensorRT,我首先必须使用此repo转换为tensorflow,它会生成以下文件:

  • 检查点
  • yolo-obj.ckpt.data-00000-of-00001
  • yolo-obj.ckpt.index
  • yolo-obj.ckpt.meta
  • yolo-obj.pb

然后我使用以下代码将其转换为tensorRT

from tensorflow.python.compiler.tensorrt 
import trt_convert as trt
import tensorflow as tf


with tf.Session() as sess:
    # convert into frozen graph
    saver = tf.train.import_meta_graph('data/yolo-obj.ckpt.meta')
    saver.restore(sess, 'data/yolo-obj.ckpt')

    output_nodes = ["save/restore_all"]

    frozen_graph = tf.graph_util.convert_variables_to_constants(
                   sess,
                   tf.get_default_graph().as_graph_def(),
                   output_node_names=output_nodes)

    # convert into trt
    converter = trt.TrtGraphConverter(
                 input_graph_def=frozen_graph,
                 nodes_blacklist=output_nodes)
    trt_graph = converter.convert()
    output_node = tf.import_graph_def(
                 trt_graph,
                 return_elements=output_nodes)

    sess.run(output_node)

但是我得到了错误:

ValueError:节点save / Assign的输入0从yolov3 / convolutional1 / BatchNorm / beta:0传递给float,与预期的float_ref不兼容。

0 个答案:

没有答案