我正在使用Keras调整一个带有一些附加层的ResNet50模型。 我需要知道每批训练哪些图像。
我遇到的问题是只有imagedata及其标签,但在fit和fit_generator中没有传递图像名称,以便将批量训练的图像名称输出到文件中。
答案 0 :(得分:1)
您可以创建自己的生成器,以便跟踪输入网络的内容,并根据数据执行任何操作(即将索引与图像匹配)。
以下是您可以构建的生成器函数的基本示例:
def gen_data():
x_train = np.random.rand(100, 784)
y_train = np.random.randint(0, 1, 100)
i = 0
while True:
indices = np.arange(i*10, 10*i+10)
# Those are indices being fed to network which can be saved to a file.
print(indices)
out = x_train[indices], y_train[indices]
i = (i+1) % 10
yield out
然后将fit_generator
与新定义的生成器函数一起使用:
model.fit_generator(gen_data(), steps_per_epoch=10, epochs=20)