保存Keras增强数据作为一个numpy数组

时间:2017-08-10 11:46:34

标签: python numpy tensorflow keras

使用keras ImageDataGenerator , 我们可以将增强图像保存为png或jpg:

    for X_batch, y_batch in datagen.flow(train_data, train_labels, batch_size=batch_size,\
                save_to_dir='images', save_prefix='aug', save_format='png'):

我有一个形状的数据集(1600,4,100,100),这意味着1600个图像有4个100x100像素的通道。如何将增强数据保存为numpy数组形状(N,4,100,100)而不是单个图像?

1 个答案:

答案 0 :(得分:0)

由于您知道样本数量= 1600,因此只要达到此数量,您就可以停止datagen.flow()

augmented_data = []
num_augmented = 0
for X_batch, y_batch in datagen.flow(train_data, train_labels, batch_size=batch_size, shuffle=False):
    augmented_data.append(X_batch)
    num_augmented += batch_size
    if num_augmented == train_data.shape[0]:
        break
augmented_data = np.concatenate(augmented_data)
np.save(...)

请注意,您应该正确设置batch_size(例如batch_size=10),以便不会生成额外的增强图片。