在Keras ResNet50模型中获取每批图像的名称

时间:2018-04-19 17:15:20

标签: python tensorflow keras resnet

我正在使用Keras调整一个带有一些附加层的ResNet50模型。 我需要知道每批训练哪些图像。

我遇到的问题是只有imagedata及其标签,但在fit和fit_generator中没有传递图像名称,以便将批量训练的图像名称输出到文件中。

1 个答案:

答案 0 :(得分:1)

您可以创建自己的生成器,以便跟踪输入网络的内容,并根据数据执行任何操作(即将索引与图像匹配)。

以下是您可以构建的生成器函数的基本示例:

def gen_data():
    x_train = np.random.rand(100, 784)
    y_train = np.random.randint(0, 1, 100)

    i = 0
    while True:  
        indices = np.arange(i*10, 10*i+10)

        # Those are indices being fed to network which can be saved to a file.
        print(indices)
        out = x_train[indices],  y_train[indices]
        i = (i+1) % 10
        yield out

然后将fit_generator与新定义的生成器函数一起使用:

model.fit_generator(gen_data(), steps_per_epoch=10, epochs=20)