Why does ImageDataGenerator iterate forever?

Posted: 2019-05-21 06:33:51

Tags: python tensorflow keras

I have just started using Keras and was doing some image preprocessing, and I noticed that the generator returned by ImageDataGenerator iterates indefinitely in a for-loop.

image_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255, rotation_range=45)

train_data_gen = image_gen.flow_from_directory(train_dir,
                                               shuffle=True,
                                               target_size=(IMG_SHAPE, IMG_SHAPE),
                                               batch_size=batch_size)
print('Total number of batches - {}'.format(len(train_data_gen)))
for n, i in enumerate(train_data_gen):
    if n >= 30:
        # I have to add an explicit break to get out of the loop
        # once all the items in the generator have been iterated over.
        break
    batch_data = i[0]
    print(n, batch_data[0].shape)
# Try to access an element out of bounds to see if there really are more than 30 elements.
print(''.format(train_data_gen[32]))

Output

Found 2935 images belonging to 5 classes.
Total number of batches - 30
0 (150, 150, 3)
1 (150, 150, 3)
2 (150, 150, 3)
.
.
.
29 (150, 150, 3)
---------------------------------------------------------------------------
ValueError: Traceback (most recent call last)
<ipython-input-20-aed377bb98f7> in <module>
     13     batch_data = i[0]
     14     print(n, batch_data[0].shape)
---> 15 print(''.format(train_data_gen[32]))

~/.virtualenvs/pan_demo/lib/python3.6/site-packages/keras_preprocessing/image/iterator.py in __getitem__(self, idx)
     55                              'but the Sequence '
     56                              'has length {length}'.format(idx=idx,
---> 57                                                           length=len(self)))
     58         if self.seed is not None:
     59             np.random.seed(self.seed + self.total_batches_seen)

ValueError: Asked to retrieve element 32, but the Sequence has length 30

Questions

  1. Is this how ImageDataGenerator works? If so, can I somehow avoid the `if n >= 30` check?
  2. Am I missing something in how I set up the generator that causes this behaviour?

Keras版本:if n >=30 ---> tf.keras.__version__ Tensorflow版本:2.2.4-tf ---> tf.VERSION

2 Answers:

Answer 0 (score: 2)

In fact, train_data_gen generates batches of data indefinitely.

When calling model.fit_generator(), we pass train_data_gen as the generator and set steps_per_epoch (which should be len(train_data) / batch_size). The model then knows when a single epoch is finished.
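As a sketch of that arithmetic (pure Python; `batch_size = 100` is an assumption, not stated in the question, chosen so the numbers match the 30 batches reported):

```python
import math

num_images = 2935   # "Found 2935 images" in the question's output
batch_size = 100    # assumed value; the question does not state it

# fit_generator stops drawing from the infinite generator after this many
# batches, which is what marks the end of one epoch.
steps_per_epoch = math.ceil(num_images / batch_size)
print(steps_per_epoch)  # 30, matching "Total number of batches - 30"
```

With `steps_per_epoch` set, no manual break is needed inside the training loop.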

Answer 1 (score: 1)

From the documentation:

for e in range(epochs):
    print('Epoch', e)
    batches = 0
    for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32):
        model.fit(x_batch, y_batch)
        batches += 1
        if batches >= len(x_train) / 32:
            # we need to break the loop by hand because
            # the generator loops indefinitely
            break
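As an alternative to counting batches by hand, the loop can also be capped with `itertools.islice`; below is a sketch where `fake_flow` is a hypothetical stand-in for `datagen.flow(...)`, not part of the original answer:

```python
from itertools import islice

def fake_flow(batches_per_epoch=30):
    """Stand-in for datagen.flow(...): cycles through the batches forever."""
    i = 0
    while True:
        yield i % batches_per_epoch  # a real flow would yield (x_batch, y_batch)
        i += 1

# islice stops after exactly one epoch's worth of batches -- no manual break needed.
one_epoch = list(islice(fake_flow(), 30))
print(len(one_epoch))  # 30
```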