使用ImageDataGenerator进行无限循环

时间:2020-07-10 14:06:56

标签: tensorflow for-loop machine-learning keras infinite-loop

我正在使用更大的数据集,因此,必须批量将数据加载到RAM中,以便在不耗尽资源的情况下更快地运行。我在.flow

中使用Image Data Generator

使用for循环会导致一个无限循环,该循环会不断生成相同批处理大小的图像,然后再循环播放以重新开始。准备代码如下所示:

train_dataset=tf.keras.preprocessing.image.ImageDataGenerator(featurewise_center=False, samplewise_center=False,
    featurewise_std_normalization=False, samplewise_std_normalization=False,
    zca_whitening=False, rotation_range=0, width_shift_range=0.0,
    height_shift_range=0.0, brightness_range=None, shear_range=0.0, zoom_range=0.0,
    channel_shift_range=0.0, cval=0.0, horizontal_flip=False,
    vertical_flip=False, preprocessing_function=None,
    data_format=None, validation_split=0.0, dtype=None)
train_dataset.fit(X)

接着是如下所示的尝试循环:

for images, y_batch in train_dataset.flow(X, y, batch_size=batch_size):
          print(np.shape(images))

代码只保留返回的维度数组:

(batch_size,img_size,img_size,3) (我需要这些图像来将数据带入我的RAM中以执行反向道具)。请注意,我没有使用model.fit之类的东西,需要通过我的正确代码运行这些数组。

不太确定如何添加停止条件

1 个答案:

答案 0 :(得分:1)

这就是重点;永远继续迭代。 Keras的model.fit_gerentaor()或tf.keras的model.fit()句柄根据epochssteps_per_epoch参数终止训练循环。

如果要手动使用ImageDataGenerator()训练模型,则可以大致执行以下操作:

epochs = 10
steps_per_epoch = len(x) // batch_size + 1  # we usually consider 1 epoch to be
                                            # the point where the model has seen
                                            # all the training samples at least once

generator = train_dataset.flow(X, y, batch_size=batch_size)

for e in range(epochs):
    for i, (images, y_batch) in enumerate(generator):
       model.train_on_batch(images, y_batch)  # train model for a single iteration
       if i >= steps_per_epoch:  # manually detect the end of the epoch
           break  
    generator.on_epoch_end()  # this shuffles the data at the end of each epoch