Your input ran out of data; interrupting training

Time: 2020-05-20 16:30:16

Tags: python tensorflow machine-learning keras

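I'm trying to train a model on a dataset of 160 images (80 cars, 80 planes) using the image data augmentation generator (ImageDataGenerator) in TensorFlow v2.1. When I run the following code, I get an error: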

from keras.preprocessing.image import ImageDataGenerator
classifier.compile(optimizer='adam', loss='binary_crossentropy',
                   metrics=['accuracy'])
# Augmentation for training; only rescaling for validation
train_imagedata = ImageDataGenerator(rescale=1. / 255, shear_range=0.2,
                                     zoom_range=0.2, horizontal_flip=True)
test_imagedata = ImageDataGenerator(rescale=1. / 255)
training_set = train_imagedata.flow_from_directory(
    'data/training_set', target_size=(64, 64), batch_size=32,
    class_mode='binary')
val_set = test_imagedata.flow_from_directory(
    'data/val_set', target_size=(64, 64), batch_size=32,
    class_mode='binary')
history = classifier.fit(training_set, steps_per_epoch=30, epochs=30,
                         validation_data=val_set,
                         validation_steps=30)

The error is:

Found 160 images belonging to 2 classes.
Found 40 images belonging to 2 classes.
Epoch 1/30
 5/30 [====>.........................] - ETA: 5s - loss: 0.5002 - accuracy: 0.8313WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 900 batches). You may need to use the repeat() function when building your dataset.
WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 30 batches). You may need to use the repeat() function when building your dataset.
 5/30 [====>.........................] - 2s 416ms/step - loss: 0.5002 - accuracy: 0.8313 - val_loss: 1.6599 - val_accuracy: 0.5000

What can I do here to fix this error? Thanks in advance!

1 Answer:

Answer 0 (score: 2)

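The default batch_size of flow_from_directory is 32, which is also the value you pass explicitly. You ask for 30 steps per epoch, i.e. 30 * 32 = 960 images per epoch, but you only have 160 images, so the generator is exhausted after 5 batches (160 / 32) and Keras interrupts training. You need to set steps_per_epoch to floor(num_of_images / batch_size), which is 5 here. The second warning, which mentions 30 batches, matches validation_steps=30: that value is likewise too large for your 40 validation images, which yield only 2 batches of 32 (the second one partial).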
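The arithmetic in this answer can be checked with a few lines of plain Python. This is a minimal sketch: `max_steps` is a hypothetical helper, not part of Keras or TensorFlow, and the image counts are the ones `flow_from_directory` reported in the question. The `drop_last=False` branch assumes that the iterator yields a final, smaller batch when the image count is not a multiple of the batch size, as Keras's `DirectoryIterator` does.

```python
import math


def max_steps(num_images: int, batch_size: int, drop_last: bool = True) -> int:
    """Hypothetical helper: how many batches one pass over the data supplies.

    drop_last=True  -> floor(num_images / batch_size), the value suggested in
                       the answer; it never asks for more batches than exist.
    drop_last=False -> ceil(num_images / batch_size), which also counts a final
                       partial batch (assumes the iterator yields one).
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_last:
        return num_images // batch_size
    return math.ceil(num_images / batch_size)


batch_size = 32
num_train, num_val = 160, 40  # counts printed by flow_from_directory

# What the question requested per epoch vs. what the training data can provide.
requested_images = 30 * batch_size  # steps_per_epoch=30
print(requested_images, "requested vs", num_train, "available")

steps_per_epoch = max_steps(num_train, batch_size)                  # 5
validation_steps = max_steps(num_val, batch_size, drop_last=False)  # 2
print(steps_per_epoch, validation_steps)
```

These values would then replace the two 30s in the `classifier.fit(...)` call. With `tf.keras` in TF 2.1 you can usually also omit both arguments for a `DirectoryIterator`, since it is a `Sequence` whose `len()` already gives the ceil-based batch count.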