如何在keras flow_from_directory中手动指定类标签?

时间:2017-03-29 07:02:58

标签: python image-processing deep-learning keras multilabel-classification

问题:我正在培训多标签图像识别模型。因此,我的图像与多个y标签相关联。这与ImageDataGenerator的方便keras方法“flow_from_directory”相冲突,其中每个图像应该位于相应标签(https://keras.io/preprocessing/image/)的文件夹中。

解决方法:目前,我正在将所有图像读入一个numpy数组并从那里使用“flow”功能。但这会导致大量内存负载和缓慢的读入过程。

问题:有没有办法使用“flow_from_directory”方法并手动提供(多个)类标签?

更新:我最终扩展了多标签案例的DirectoryIterator类。您现在可以将属性“class_mode”设置为值“multilabel”,并提供字典“multlabel_classes”,它将文件名映射到其标签。代码:https://github.com/tholor/keras/commit/29ceafca3c4792cb480829c5768510e4bdb489c5

3 个答案:

答案 0 :(得分:8)

您可以简单地使用flow_from_directory并以下列方式将其扩展为多类:

def multiclass_flow_from_directory(flow_from_directory_gen, multiclasses_getter):
    for x, y in flow_from_directory_gen:
        yield x, multiclasses_getter(x, y)

multiclasses_getter在哪里为您的图像分配多类矢量/您的多类表示。请注意,xy不是单个示例,而是批量示例,因此应将其包含在multiclasses_getter设计中。

答案 1 :(得分:2)

您可以编写一个自定义生成器类,该类将从目录中读取文件并应用标签。该自定义生成器还可以接收ImageDataGenerator实例,该实例将使用flow()生成批处理。

我想象的是这样的事情:

class Generator():

    def __init__(self, X, Y, img_data_gen, batch_size):
        self.X = X
        self.Y = Y  # Maybe a file that has the appropriate label mapping?
        self.img_data_gen = img_data_gen  # The ImageDataGenerator Instance
        self.batch_size = batch_size

    def apply_labels(self):
        # Code to apply labels to each sample based on self.X and self.Y

    def get_next_batch(self):
        """Get the next training batch"""
        self.img_data_gen.flow(self.X, self.Y, self.batch_size)

然后简单地说:

img_gen = ImageDataGenerator(...)
gen = Generator(X, Y, img_gen, 128)

model.fit_generator(gen.get_next_batch(), ...)

*免责声明:我实际上没有对此进行测试,但它应该在理论上有用。

答案 2 :(得分:0)

# Training the model
history = model.fit(train_generator, steps_per_epoch=steps_per_epoch, epochs=3, validation_data=val_generator,validation_steps=validation_steps, verbose=1,
                    callbacks= keras.callbacks.ModelCheckpoint(filepath='/content/results',monitor='val_accuracy', save_best_only=True,save_weights_only=False))

validation_stepssteps_per_epoch可能超过原始参数。

steps_per_epoch= (int(num_of_training_examples/batch_size)可能会有所帮助。 同样,validation_steps= (int(num_of_val_examples/batch_size)会有所帮助