使用下面的Keras Python代码:
for x_batch,y_batch in datagen.flow_from_directory(
directory = os.path.join(dataset_root_path,dataset_train_path),
target_size = (520,520),
class_mode = 'binary',
batch_size = 1
):
我得到了x_batch和y_batch numpy数组,y_batch numpy数组被编码为数字0.0或1.0,因为我使用的是“二进制”class_mode,但是,这样,我丢失了关于什么是真实标签的信息对于该样本,例如,“cat”或“dog”。如何根据输出“1.0”和“0.0”检索标签信息?
答案 0 :(得分:0)
我建议实例化datagenerator并使用insert ='INSERT INTO {0} VALUES("{1}","{2}")'.format(tab,item1,item2)
:
fit_generator
然后,您可以使用train_gen = datagen.flow_from_directory( ... )
model.fit_generator(train_gen, ...)
访问(以及其他属性)类索引。