生成器 (ImageDataGenerator) 如何耗尽数据?

时间:2021-05-19 13:10:13

标签: python tensorflow keras

让我们从一个包含 1000 张图片的文件夹开始。

现在,如果我们使用无生成器batch_size = 10steps_per_epoch = 100,我们将使用每张图片作为10 * 100 = 1000。因此增加 steps_per_epoche 将(理所当然地)导致错误:

<块引用>

tensorflow:你的输入数据用完了;中断训练。确保您的数据集或生成器可以生成至少 steps_per_epoch * epochs 个批次(在本例中为 10000 个批次)

另一方面,使用生成器将导致无限批量的图像:

datagenerator = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.1,
    zoom_range=0.1,
    # ...
)

imageFlow = datagenerator.flow_from_directory(
        image_dir_with_1000_pcs,
        target_size=(150, 150),
        batch_size=10,
        class_mode='binary')

i = 0
for x, y in imageFlow:
  print(x.shape) # batch of images
  
  i += 1
  if i > 3000:
    break # I break, because it ENDLESSLY goes on otherwise

但是,如果我去跑

history = model.fit(
      imageFlow,
      steps_per_epoch=101, # I increased this above 100!
      epochs=5,
      #...
)

我会得到同样的错误:为什么? model.fit() 得到一个生成器,因此是无穷无尽的批次。当被无限批量输入时,它怎么会耗尽数据?

在发布这个问题之前,我阅读了:

1 个答案:

答案 0 :(得分:1)

<块引用>

生成器 (ImageDataGenerator) 怎么会耗尽数据?

据我所知,它从生成器创建了一个 tf.data.Dataset,它不会无限运行,这就是您在拟合时看到这种行为的原因。

如果是无限数据集,则您必须指定steps_per_epoch

编辑:如果不指定 steps_per_epoch,则训练将在 number_of_batches >= len(dataset) // batch_size 时停止。它在每个时代都完成。

要检查幕后真正发生的事情,您可以检查 the source。可以看出,创建了一个 tf.data.Dataset 并实际处理批处理和纪元迭代。