我正在研究对象检测问题,在该问题中,我使用darknet获得了经过训练的模型(.cfg和.weights文件)。为了将其转换为tensorRT,我首先必须使用此repo转换为tensorflow,它会生成以下文件:
然后我使用以下代码将其转换为tensorRT
from tensorflow.python.compiler.tensorrt
import trt_convert as trt
import tensorflow as tf
with tf.Session() as sess:
# convert into frozen graph
saver = tf.train.import_meta_graph('data/yolo-obj.ckpt.meta')
saver.restore(sess, 'data/yolo-obj.ckpt')
output_nodes = ["save/restore_all"]
frozen_graph = tf.graph_util.convert_variables_to_constants(
sess,
tf.get_default_graph().as_graph_def(),
output_node_names=output_nodes)
# convert into trt
converter = trt.TrtGraphConverter(
input_graph_def=frozen_graph,
nodes_blacklist=output_nodes)
trt_graph = converter.convert()
output_node = tf.import_graph_def(
trt_graph,
return_elements=output_nodes)
sess.run(output_node)
但是我得到了错误:
ValueError:节点save / Assign的输入0从yolov3 / convolutional1 / BatchNorm / beta:0传递给float,与预期的float_ref不兼容。