如何在使用Keras的数据生成器时检索标签信息?

时间:2017-09-23 08:08:28

标签: python keras

使用下面的Keras Python代码:

for x_batch,y_batch in datagen.flow_from_directory(
    directory = os.path.join(dataset_root_path,dataset_train_path),
    target_size = (520,520),
    class_mode = 'binary',
    batch_size = 1
):

我得到了x_batch和y_batch numpy数组,y_batch numpy数组被编码为数字0.0或1.0,因为我使用的是“二进制”class_mode,但是,这样,我丢失了关于什么是真实标签的信息对于该样本,例如,“cat”或“dog”。如何根据输出“1.0”和“0.0”检索标签信息?

1 个答案:

答案 0 :(得分:0)

我建议实例化datagenerator并使用insert ='INSERT INTO {0} VALUES("{1}","{2}")'.format(tab,item1,item2)

进行训练
fit_generator

然后,您可以使用train_gen = datagen.flow_from_directory( ... ) model.fit_generator(train_gen, ...) 访问(以及其他属性)类索引。