使用经过训练的对象检测API模型和TF 2进行批量预测

时间:2020-09-02 09:40:50

标签: tensorflow object batch-processing prediction detection

我在TPU上使用针对TF 2的对象检测API成功地训练了一个模型,该模型另存为.pb(SavedModel格式)。然后,我使用tf.saved_model.load将其加载回去,当使用将单个图像转换为形状为(1, w, h, 3)的张量来预测框时,它可以很好地工作。

import tensorflow as tf
import numpy as np

# Load Object Detection APIs model
detect_fn = tf.saved_model.load('/path/to/saved_model/')

image = tf.io.read_file(image_path)
image_np = tf.image.decode_jpeg(image, channels=3).numpy()
input_tensor = np.expand_dims(image_np, 0)
detections = detect_fn(input_tensor) # This works fine

问题是我需要进行批量预测才能将其缩放到100万张图像,但是此模型的输入签名似乎仅限于处理形状为(1, w, h, 3)的数据。 这也意味着我不能在Tensorflow Serving中使用批处理。 我怎么解决这个问题?我可以只更改Model Signature来处理大量数据吗?

所有工作(加载模型+预测)都是在由对象检测API(来自here)发布的官方容器中进行的

1 个答案:

答案 0 :(得分:2)

我最近遇到了这个问题。当您使用exporter_main_v2.py将检查点文件转换为.pb文件时,它将调用exporter_lib_v2.py。我发现在文件exporter_lib_v2.pyhere)中,TF2用形状[1, None, None, 3]固定了输入签名。我们必须将其更改为[None, None, None, 3]

需要将该文件(138162170185)中的行从1修改为None。然后重建TF2对象检测器API存储库(link),并使用新构建的版本再次导出.pb