使用Tensorflow Data API和Keras模型预测批次

时间:2019-10-16 23:39:54

标签: tensorflow keras tensorflow-datasets

假设我有一个数据集和一个Keras模型。已使用tf Dataset API中的ObservableList<A> list = FXCollections.observableArrayList(item -> new Observable[] {item.durationProperty}); list.addListener((InvalidationListener) observable -> { //Update you sum here }); 将数据集分为几批。现在,我正在寻找一种有效且干净的方法来对所有测试样品进行批量预测。

我尝试了以下代码,并且可以正常工作。

batch()

我想知道有没有更有效,更优雅的方法来实现这一目标?

1 个答案:

答案 0 :(得分:0)

TF> = 1.14.0

您只需设置steps=None。根据{{​​1}}的官方文档:

  

如果x是tf.data数据集,而steps为None,则预测将运行直到输入数据集用尽。

只需确保您的tf.keras.Model.predict()对象未处于重复模式,就可以了:)。

TF 1.12.0和1.13.0

在这些版本中,对datasettf.data.Dataset的支持非常差。 tf.keras对象将转换为迭代器here,如果您未设置tf.data.Dataset参数,则它将引发错误here。在1.14.0中对此进行了修补。