我有一个Keras模型,我用numpy数组作为X_train来输入,而y_train类似于['0', '1', ...]
,带有字符串。
当我输入模型时,我得到了上面的错误,但是只有使用自定义图像生成器才有这个问题,而如果我使用ImageDataGenerator keras类,一切都很好。您对我做错了什么建议吗?
没关系:
aug = ImageDataGenerator(
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.15,
horizontal_flip=True,
fill_mode="nearest")
def genOK(aug):
(x, y) = next(aug.flow(np.array(x),y,batch_size=batch_size))
yield x, y
model.fit_generator(genOK,...)
这不是:
my_data_augmentation_generator(X,y):
while True:
do some image transformations here
yield X_transformed, y
my_aug = my_data_augmentation_generator()
def genNOTOK(my_aug):
(x, y) = next(my_aug)
yield x, y
model.fit_generator(genNOTOK,...)