tensorflow 2.0,model.fit():您的输入数据不足

时间:2020-02-14 15:15:24

标签: python tensorflow keras neural-network

我绝对不是TensorFlow和Keras的新手,我正在尝试尝试一些我在网上找到的代码。

特别是我正在使用fashion-MNIST-由60000个示例和10000个示例组成的测试集。它们每个都是28x28灰度图像。

我正在学习本教程“ enter image description here”,并且在定义

之前没有问题。
officeId

据我所知,我需要将 train_dataset.repeat()用作输入数据集,因为否则我将没有足够的训练示例使用这些值作为超参数(历元,steps_per_epochs)。

我的问题是:如何避免不得不使用 .repeat()? 我该如何更改超参数?

为简单起见,我在这里处理代码:

history = model.fit(
train_dataset.repeat(), 
epochs=10, 
steps_per_epoch=500,
validation_data=val_dataset.repeat(), 
validation_steps=2)

谢谢!

2 个答案:

答案 0 :(得分:3)

如果您不想使用.repeat(),则需要让模型传递一次,以为每个时期仅一次整个数据。

要做到这一点,您需要计算模型通过整个数据集所要执行的步骤,计算起来很容易:

steps_per_epoch = len(train_dataset) // batch_size

因此,如果train_dataset为60000个样本,batch_size为128,则每个时期需要468个步骤。

通过像这样设置此参数,确保不超过数据集的大小。

答案 1 :(得分:0)

我遇到了同样的问题,这就是我所发现的。 tf.keras.Model.fit 的文档:“如果x是tf.data数据集,并且'steps_per_epoch'为None,则该纪元将一直运行直到输入数据集用尽。”

换句话说,如果我们使用tf.data.dataset作为训练数据,则不需要需要指定“ steps_per_epoch”,并且tf会计算出其中有多少步。同时,tf将在下一个纪元开始时自动重复数据集,因此您可以指定任何“纪元”。

传递无限重复的数据集(例如,dataset.repeat())时,必须指定steps_per_epoch参数。