我正在尝试将新的Tensorflow 2对象检测API的EfficientDet D1 640x640转换为可在Jetson AGX板上运行的TensorRT(TRT)模型。
我正在运行以下代码:
allcategorydatawithlevel
我遇到以下错误:
categorydata
该错误消息说,它期望输入形状为[1,?,?,3],但得到的形状为[640,640,3]。但是我正在传递大小为(1,640,640,3)的数组,这应该是正确的。但出于某些原因,它似乎不起作用。
我正在使用的模型是Tensorflow 2.0对象检测Api(https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md)的预训练模型
在此先感谢您的帮助!
答案 0 :(得分:0)
您应该用( ,)
装饰您的收益值
def my_input_fn():
for _ in range(num_runs):
inp1 = np.random.normal(size=(1,640,640,3)).astype(np.uint8)
yield (inp1,)
此外,我不确定TensorRT是否支持uint8,可能只是int8。
答案 1 :(得分:0)
尝试使用此代码,希望对您有所帮助:
from tensorflow.python.compiler.tensorrt import trt_convert as trt
import numpy as np
input_saved_model_dir = "your-model-dir"
output_saved_model_dir = "your-output-dir"
conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS
conversion_params = conversion_params._replace(
max_workspace_size_bytes=(1<<32))
conversion_params = conversion_params._replace(precision_mode="FP16")
conversion_params = conversion_params._replace(
maximum_cached_engines=100)
converter = trt.TrtGraphConverterV2(
input_saved_model_dir=input_saved_model_dir,
conversion_params=conversion_params)
converter.convert()
def my_input_fn():
Inp1 = np.random.normal(size=(1, 640, 640, 3)).astype(np.uint8)
yield (Inp1,)
converter.build(input_fn=my_input_fn)
converter.save(output_saved_model_dir)