将TF 2对象检测模型转换为TensorRT

时间:2020-08-04 05:48:16

标签: python python-3.x tensorflow tensorrt

我正在尝试将新的Tensorflow 2对象检测API的EfficientDet D1 640x640转换为可在Jetson AGX板上运行的TensorRT(TRT)模型。

我正在运行以下代码:

allcategorydatawithlevel

我遇到以下错误:

categorydata

该错误消息说,它期望输入形状为[1,?,?,3],但得到的形状为[640,640,3]。但是我正在传递大小为(1,640,640,3)的数组,这应该是正确的。但出于某些原因,它似乎不起作用。

我正在使用的模型是Tensorflow 2.0对象检测Api(https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf2_detection_zoo.md)的预训练模型

在此先感谢您的帮助!

2 个答案:

答案 0 :(得分:0)

您应该用( ,)装饰您的收益值

def my_input_fn():
  for _ in range(num_runs):
    inp1 = np.random.normal(size=(1,640,640,3)).astype(np.uint8)
    yield (inp1,)

此外,我不确定TensorRT是否支持uint8,可能只是int8。

答案 1 :(得分:0)

尝试使用此代码,希望对您有所帮助:

from tensorflow.python.compiler.tensorrt import trt_convert as trt
import numpy as np

input_saved_model_dir = "your-model-dir"
output_saved_model_dir = "your-output-dir"

conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS
conversion_params = conversion_params._replace(
    max_workspace_size_bytes=(1<<32))
conversion_params = conversion_params._replace(precision_mode="FP16")
conversion_params = conversion_params._replace(
    maximum_cached_engines=100)

converter = trt.TrtGraphConverterV2(
    input_saved_model_dir=input_saved_model_dir,
    conversion_params=conversion_params)
converter.convert()

def my_input_fn():
  Inp1 = np.random.normal(size=(1, 640, 640, 3)).astype(np.uint8)
  yield (Inp1,)

converter.build(input_fn=my_input_fn)
converter.save(output_saved_model_dir)