我在TPU上使用针对TF 2的对象检测API成功地训练了一个模型,该模型另存为.pb(SavedModel格式)。然后,我使用tf.saved_model.load
将其加载回去,当使用将单个图像转换为形状为(1, w, h, 3)
的张量来预测框时,它可以很好地工作。
import tensorflow as tf
import numpy as np
# Load Object Detection APIs model
detect_fn = tf.saved_model.load('/path/to/saved_model/')
image = tf.io.read_file(image_path)
image_np = tf.image.decode_jpeg(image, channels=3).numpy()
input_tensor = np.expand_dims(image_np, 0)
detections = detect_fn(input_tensor) # This works fine
问题是我需要进行批量预测才能将其缩放到100万张图像,但是此模型的输入签名似乎仅限于处理形状为(1, w, h, 3)
的数据。
这也意味着我不能在Tensorflow Serving中使用批处理。
我怎么解决这个问题?我可以只更改Model Signature来处理大量数据吗?
所有工作(加载模型+预测)都是在由对象检测API(来自here)发布的官方容器中进行的