keras_unet.utils.get_augmented从磁盘读取

时间:2020-11-10 13:09:20

标签: python-3.x keras input

我想在大型数据集(10K图像和蒙版)上训练cnn模型。

目前,我正在批量读取BATCH_SIZE = 500张图片中的数据,并以

进行扩充
def get_augmented
    return get_augmented(
        x_train, y_train, batch_size=BATCH_SIZE,
        data_gen_args=dict(
            rotation_range=5.,
            width_shift_range=0.05,
            height_shift_range=0.05,
            shear_range=40,
            zoom_range=0.2,
            horizontal_flip=True,
            vertical_flip=False,
            fill_mode='constant'
        ))
    

主循环看起来像

STEPS_PER_EPOCH = 2
INNER_EPOCHS = 2
EPOCHS = 10

model = init_model(model_filename)  # Define the model
for epoch in range(EPOCHS):  # Number of times I want to iterate over the full dataset
    for batch_id in range(0,20):  # number of chunks of 500 images
    
        # reads images and masks from disk from batch_id folder
        x_train, y_train, x_val, y_val = read_a_butch_of_masks_and_images(batch_id)   
        
        history = model.fit(
            train_gen,
            steps_per_epoch=STEPS_PER_EPOCH,
            epochs=INNER_EPOCHS,
            validation_data=(x_val, y_val),
            callbacks=[callback_checkpoint]
        )

model.save_weights(new_model_filename)

如何告诉get_augmented从磁盘读取下一批图像,而不是消耗x_train, y_train,以便它消耗很少的内存并遍历所有图像?

0 个答案:

没有答案
相关问题