我使用tf.estimator.Estimator类来训练一个神经网络,使用TFRecordDataset为它提供图像和标签。
这一切都很好。
现在我试图通过网络摄像头的实时帧进行预测。我为它做了另一个输入功能,它工作正常,但它真的很慢。在训练阶段,我要经历约100张图像/秒。下面的推理代码每帧需要几秒钟。
我在思考,也许我应该修改我的输入功能,以便它以某种方式返回一个从实时反馈中一帧又一帧地返回的永无止境的数据集。
对此的任何意见都表示赞赏。
这就是我现在所拥有的:
def live_input_fn():
ret, frame = video_capture.read()
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame = tf.image.convert_image_dtype(frame, dtype=tf.float32)
frame = tf.image.per_image_standardization(frame)
frame = tf.reshape(frame, [1, 120, 160, 3])
dataset = tf.contrib.data.Dataset.from_tensor_slices(frame)
iterator = dataset.make_one_shot_iterator()
features = iterator.get_next()
return features, None
toy_classifier = tf.estimator.Estimator(
model_fn=model_fn, model_dir=script_dir + "/cnn_tflayers")
while True:
predictions = toy_classifier.predict(
input_fn=live_input_fn
)
for p in predictions:
print(label_names[p["classes"]])
print(p["probabilities"])