如何在keras中使用flow_from_dataframe输入多个图像?

时间:2019-06-27 17:50:15

标签: python dataframe keras

我一直在尝试创建暹罗模型,以查找2张图片(有2张输入图片)之间的图片相似度。一开始,我使用一个小的数据集对其进行了测试,并将其安装在我的RAM中,并且运行良好。现在,我想增加训练样本的大小,并为此创建了 images.csv 文件。在此文件中,我有3列: image_1image_2similarity

image_1 image_2 是图像的绝对路径。 similarity是0或1。

我尝试过

generator.flow_from_dataframe(dataframe, target_size=(64, 64, 1), x_col=['image_1', 'image_2'],
                                                    y_col='similarity',
                                                    class_mode='sparse', subset='training')

但出现此错误:

  

ValueError:x_col = ['image_1','image_2']列中的所有值都必须是   字符串。

删除 image_2 并出现 x_col=image_1 错误后,

消失了,但只有1张输入图像。

我该怎么办?

2 个答案:

答案 0 :(得分:1)

您不能使用该方法从单个发生器生成两个图像,该方法只能处理documentation中的一个图像:

  

x_col:字符串,数据帧中包含文件名的列(如果目录为None,则为绝对路径)。

相反,您可以做的是创建两个生成器,并更适当地允许您的网络具有两个输入

in1 = generator.flow_from_dataframe(dataframe, target_size=(64, 64, 1), x_col='image_1',
                                                    y_col='similarity',
                                                    class_mode='sparse', subset='training')

in2 = generator.flow_from_dataframe(dataframe, target_size=(64, 64, 1), x_col='image_2',
                                                    y_col='similarity',
                                                    class_mode='sparse', subset='training')

然后使用functional API构建一个接受两个图像输入的模型:

input_image1 = Input(shape=(64, 64, 1))
input_image2 = Input(shape=(64, 64, 1))
# ... all other layers to create output_layer
model = Model([input_image1, input_image2], output)
# ...

这更能反映出您的模型实际上有2个输入作为图像。

答案 1 :(得分:0)

借助@nuric,我能够输入多个图像。这是创建流程的完整代码:

def get_flow_from_dataframe(generator, dataframe,
                            image_shape=(64, 64),
                            subset='training',
                            color_mode='grayscale', batch_size=64):
    train_generator_1 = generator.flow_from_dataframe(dataframe, target_size=image_shape,
                                                      color_mode=color_mode,
                                                      x_col='image_1',
                                                      y_col='prediction',
                                                      class_mode='binary',
                                                      shuffle=True,
                                                      batch_size=batch_size,
                                                      seed=7,
                                                      subset=subset, drop_duplicates=False)

    train_generator_2 = generator.flow_from_dataframe(dataframe, target_size=image_shape,
                                                      color_mode=color_mode,
                                                      x_col='image_2',
                                                      y_col='prediction',
                                                      class_mode='binary',
                                                      shuffle=True,
                                                      batch_size=batch_size,
                                                      seed=7,
                                                      subset=subset, drop_duplicates=False)
    while True:
        x_1 = train_generator_1.next()
        x_2 = train_generator_2.next()

        yield [x_1[0], x_2[0]], x_1[1]

fit_generator的完整代码:

train_gen = get_flow_from_dataframe(generator, dataframe, image_shape=(64, 64),
                                        color_mode='rgb',
                                        batch_size=batch_size)
valid_gen = get_flow_from_dataframe(generator, dataframe, image_shape=(64, 64),
                                        color_mode='rgb',
                                        batch_size=batch_size,
                                        subset='validation')

model.fit_generator(train_gen, epochs=50,
                        steps_per_epoch=step_size,
                        validation_data=valid_gen,
                        validation_steps=step_size,
                        callbacks=get_call_backs('../models/model_1.h5', monitor='val_acc'),
                        )

我也看到内存消耗巨大。